optimiser

This commit is contained in:
Roland Grinis 2021-07-12 20:39:29 +01:00
parent 1029871047
commit 3dd915e2fb
4 changed files with 88 additions and 8 deletions

View File

@ -307,4 +307,14 @@ class JNoa {
public static native long getModuleBuffer(long jitModuleHandle, String name); public static native long getModuleBuffer(long jitModuleHandle, String name);
public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle); public static native void setModuleBuffer(long jitModuleHandle, String name, long tensorHandle);
public static native long adamOptim(long jitModuleHandle, double learningRate);
public static native void disposeAdamOptim(long adamOptHandle);
public static native void stepAdamOptim(long adamOptHandle);
public static native void zeroGradAdamOptim(long adamOptHandle);
public static native void swapTensors(long lhsHandle, long rhsHandle);
} }

View File

@ -5,7 +5,6 @@
package space.kscience.kmath.noa package space.kscience.kmath.noa
import com.sun.security.auth.module.JndiLoginModule
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.noa.memory.NoaScope import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
@ -143,22 +142,25 @@ protected constructor(protected val scope: NoaScope) :
public abstract fun loadJitModule(path: String, device: Device): NoaJitModule public abstract fun loadJitModule(path: String, device: Device): NoaJitModule
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType = public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
wrap(JNoa.forwardPass(this.jitModuleHandle, parameters.tensor.tensorHandle)) wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit = public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
JNoa.forwardPassAssign(this.jitModuleHandle, parameters.tensorHandle) JNoa.forwardPassAssign(jitModuleHandle, parameters.tensorHandle)
public fun NoaJitModule.getParameter(name: String): TensorType = public fun NoaJitModule.getParameter(name: String): TensorType =
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name)) wrap(JNoa.getModuleParameter(jitModuleHandle, name))
public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit = public fun NoaJitModule.setParameter(name: String, parameter: Tensor<T>): Unit =
JNoa.setModuleParameter(this.jitModuleHandle, name, parameter.tensor.tensorHandle) JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)
public fun NoaJitModule.getBuffer(name: String): TensorType = public fun NoaJitModule.getBuffer(name: String): TensorType =
wrap(JNoa.getModuleParameter(this.jitModuleHandle, name)) wrap(JNoa.getModuleParameter(jitModuleHandle, name))
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit = public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
JNoa.setModuleBuffer(this.jitModuleHandle, name, buffer.tensor.tensorHandle) JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)
public infix fun TensorType.swap(other: TensorType): Unit =
JNoa.swapTensors(tensorHandle, other.tensorHandle)
} }
@ -315,7 +317,10 @@ protected constructor(scope: NoaScope) :
wrap(JNoa.tensorGrad(tensorHandle)) wrap(JNoa.tensorGrad(tensorHandle))
public fun NoaJitModule.train(status: Boolean): Unit = public fun NoaJitModule.train(status: Boolean): Unit =
JNoa.trainMode(this.jitModuleHandle, status) JNoa.trainMode(jitModuleHandle, status)
public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))
} }
public sealed class NoaDoubleAlgebra public sealed class NoaDoubleAlgebra

View File

@ -0,0 +1,25 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa
import space.kscience.kmath.noa.memory.NoaResource
import space.kscience.kmath.noa.memory.NoaScope
internal typealias OptimiserHandle = Long
public abstract class NoaOptimiser
internal constructor(scope: NoaScope) : NoaResource(scope) {
public abstract fun step(): Unit
public abstract fun zeroGrad(): Unit
}
public class AdamOptimiser
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
: NoaOptimiser(scope) {
override fun dispose(): Unit = JNoa.disposeAdamOptim(optimiserHandle)
override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle)
override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle)
}

View File

@ -1159,6 +1159,46 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_getModuleBuffer
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_setModuleBuffer
(JNIEnv *, jclass, jlong, jstring, jlong); (JNIEnv *, jclass, jlong, jstring, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: adamOptim
* Signature: (JD)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_adamOptim
(JNIEnv *, jclass, jlong, jdouble);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: disposeAdamOptim
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_disposeAdamOptim
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: stepAdamOptim
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_stepAdamOptim
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: zeroGradAdamOptim
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim
(JNIEnv *, jclass, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: swapTensors
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors
(JNIEnv *, jclass, jlong, jlong);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif