From 3dd915e2fb34d39a7ea3f7d1520cd0222859acc7 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Mon, 12 Jul 2021 20:39:29 +0100 Subject: [PATCH] optimiser --- .../java/space/kscience/kmath/noa/JNoa.java | 10 +++++ .../space/kscience/kmath/noa/algebras.kt | 21 ++++++---- .../kotlin/space/kscience/kmath/noa/optim.kt | 25 ++++++++++++ .../resources/space_kscience_kmath_noa_JNoa.h | 40 +++++++++++++++++++ 4 files changed, 88 insertions(+), 8 deletions(-) create mode 100644 kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index b9afd283b..957387867 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -307,4 +307,14 @@ class JNoa { public static native long getModuleBuffer(long jitModuleHandle, String name); public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle); + + public static native long adamOptim(long jitModuleHandle, double learningRate); + + public static native void disposeAdamOptim(long adamOptHandle); + + public static native void stepAdamOptim(long adamOptHandle); + + public static native void zeroGradAdamOptim(long adamOptHandle); + + public static native void swapTensors(long lhsHandle, long rhsHandle); } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 14580613b..baf4b99f9 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -5,7 +5,6 @@ package space.kscience.kmath.noa -import com.sun.security.auth.module.JndiLoginModule import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.noa.memory.NoaScope import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra @@ -143,22 +142,25 @@ protected constructor(protected val scope: NoaScope) : public abstract fun loadJitModule(path: String, device: Device): NoaJitModule public fun NoaJitModule.forward(parameters: Tensor): TensorType = - wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle)) + wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle)) public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit = - JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle) + JNoa.forwardPassAssign(jitModuleHandle, parameters.tensorHandle) public fun NoaJitModule.getParameter(name: String): TensorType = - wrap(JNoa.getModuleParameter(this.jitModuleHandle, name)) + wrap(JNoa.getModuleParameter(jitModuleHandle, name)) public fun NoaJitModule.setParameter(name: String, parameter: Tensor): Unit = - JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle) + JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle) public fun NoaJitModule.getBuffer(name: String): TensorType = - wrap(JNoa.getModuleParameter(this.jitModuleHandle, name)) + wrap(JNoa.getModuleParameter(jitModuleHandle, name)) public fun NoaJitModule.setBuffer(name: String, buffer: Tensor): Unit = - JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle) + JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle) + + public infix fun TensorType.swap(other: TensorType): Unit = + JNoa.swapTensors(tensorHandle, other.tensorHandle) } @@ -315,7 +317,10 @@ protected constructor(scope: NoaScope) : wrap(JNoa.tensorGrad(tensorHandle)) public fun NoaJitModule.train(status: Boolean): Unit = - JNoa.trainMode(this.jitModuleHandle, status) + JNoa.trainMode(jitModuleHandle, status) + + public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser = + AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate)) } public sealed class NoaDoubleAlgebra diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt new file mode 100644 index 000000000..bf91968dc --- /dev/null +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt @@ -0,0 +1,25 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.noa + +import space.kscience.kmath.noa.memory.NoaResource +import space.kscience.kmath.noa.memory.NoaScope + +internal typealias OptimiserHandle = Long + +public abstract class NoaOptimiser +internal constructor(scope: NoaScope) : NoaResource(scope) { + public abstract fun step(): Unit + public abstract fun zeroGrad(): Unit +} + +public class AdamOptimiser +internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle) + : NoaOptimiser(scope) { + override fun dispose(): Unit = JNoa.disposeAdamOptim(optimiserHandle) + override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle) + override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle) +} \ No newline at end of file diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index 0254efa6d..fedf394fa 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -1159,6 +1159,46 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer (JNIEnv *, jclass, jlong, jstring, jlong); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: adamOptim + * Signature: (JD)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_adamOptim + (JNIEnv *, jclass, jlong, jdouble); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: disposeAdamOptim + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeAdamOptim + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: stepAdamOptim + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_stepAdamOptim + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: zeroGradAdamOptim + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim + (JNIEnv *, jclass, jlong); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: swapTensors + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors + (JNIEnv *, jclass, jlong, jlong); + #ifdef __cplusplus } #endif