From 102987104793b8d18691eb34b8ad665b93a000dc Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Mon, 12 Jul 2021 17:18:02 +0100 Subject: [PATCH] data for jit modules --- .../java/space/kscience/kmath/noa/JNoa.java | 8 +++++ .../space/kscience/kmath/noa/algebras.kt | 17 ++++++++-- .../resources/space_kscience_kmath_noa_JNoa.h | 32 +++++++++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index 2e96ed03a..b9afd283b 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -299,4 +299,12 @@ class JNoa { public static native long forwardPass(long jitModuleHandle, long tensorHandle); public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle); + + public static native long getModuleParameter(long jitModuleHandle, String name); + + public static native void setModuleParameter(long jitModuleHandle, String name, long tensorHandle); + + public static native long getModuleBuffer(long jitModuleHandle, String name); + + public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle); } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 47c18f8cf..14580613b 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -142,12 +142,24 @@ protected constructor(protected val scope: NoaScope) : public abstract fun loadJitModule(path: String, device: Device): NoaJitModule - public fun NoaJitModule.forward(parameters: TensorType): TensorType = - wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensorHandle)) + public fun NoaJitModule.forward(parameters: Tensor): TensorType = + wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle)) public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit = JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle) + public fun NoaJitModule.getParameter(name: String): TensorType = + wrap(JNoa.getModuleParameter(this.jitModuleHandle, name)) + + public fun NoaJitModule.setParameter(name: String, parameter: Tensor): Unit = + JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle) + + public fun NoaJitModule.getBuffer(name: String): TensorType = + wrap(JNoa.getModuleParameter(this.jitModuleHandle, name)) + + public fun NoaJitModule.setBuffer(name: String, buffer: Tensor): Unit = + JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle) + } public sealed class NoaPartialDivisionAlgebra> @@ -304,7 +316,6 @@ protected constructor(scope: NoaScope) : public fun NoaJitModule.train(status: Boolean): Unit = JNoa.trainMode(this.jitModuleHandle, status) - } public sealed class NoaDoubleAlgebra diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index 12f2685ba..0254efa6d 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -1127,6 +1127,38 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign (JNIEnv *, jclass, jlong, jlong); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: getModuleParameter + * Signature: (JLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleParameter + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setModuleParameter + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleParameter + (JNIEnv *, jclass, jlong, jstring, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: getModuleBuffer + * Signature: (JLjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: setModuleBuffer + * Signature: (JLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer + (JNIEnv *, jclass, jlong, jstring, jlong); + #ifdef __cplusplus } #endif