add options for optimizers

Anastasia Golovina 2022-05-09 18:22:05 +03:00
parent b288c4ce59
commit 3c92bfda59
3 changed files with 65 additions and 20 deletions


@@ -315,8 +315,9 @@ class JNoa {
     public static native void stepAdamOptim(long adamOptHandle);

     public static native void zeroGradAdamOptim(long adamOptHandle);

-    public static native long rmsOptim(long jitModuleHandle, double learningRate);
+    public static native long rmsOptim(long jitModuleHandle, double learningRate, double alpha,
+            double eps, double weight_decay, double momentum, boolean centered);

     public static native void disposeRmsOptim(long rmsOptHandle);

@@ -324,7 +325,8 @@ class JNoa {
     public static native void zeroGradRmsOptim(long rmsOptHandle);

-    public static native long adamWOptim(long jitModuleHandle, double learningRate);
+    public static native long adamWOptim(long jitModuleHandle, double learningRate, double beta1,
+            double beta2, double eps, double weight_decay, boolean amsgrad);

     public static native void disposeAdamWOptim(long adamWOptHandle);

@@ -332,7 +334,8 @@ class JNoa {
     public static native void zeroGradAdamWOptim(long adamWOptHandle);

-    public static native long adagradOptim(long jitModuleHandle, double learningRate);
+    public static native long adagradOptim(long jitModuleHandle, double learningRate, double weight_decay,
+            double lr_decay, double initial_accumulator_value, double eps);

     public static native void disposeAdagradOptim(long adagradOptHandle);

@@ -340,7 +343,8 @@ class JNoa {
     public static native void zeroGradAdagradOptim(long adagradOptHandle);

-    public static native long sgdOptim(long jitModuleHandle, double learningRate);
+    public static native long sgdOptim(long jitModuleHandle, double learningRate, double momentum,
+            double dampening, double weight_decay, boolean nesterov);

     public static native void disposeSgdOptim(long sgdOptHandle);
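
The extra native parameters line up one-to-one with the options of the corresponding PyTorch optimizers (torch.optim.RMSprop, AdamW, Adagrad and SGD). As a rough sketch, not part of the commit, this is how the expanded rmsOptim entry point could be driven with PyTorch's documented defaults; the helper name is hypothetical, and it assumes the calling code sits in the same package as the package-private JNoa class and already holds a handle to a loaded TorchScript module.

// Hypothetical helper, not in this commit: create an RMSprop optimiser through the raw
// JNI layer using PyTorch's documented defaults for everything except the learning rate.
fun rmsOptimWithTorchDefaults(jitModuleHandle: Long, learningRate: Double): Long =
    JNoa.rmsOptim(
        jitModuleHandle,
        learningRate,
        0.99,   // alpha: smoothing constant
        1e-8,   // eps: added to the denominator for numerical stability
        0.0,    // weight_decay
        0.0,    // momentum
        false   // centered = false: plain (non-centered) RMSprop
    )
// The returned handle must eventually be released with JNoa.disposeRmsOptim.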


@@ -343,17 +343,59 @@ protected constructor(scope: NoaScope) :
     public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
         AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))

-    public fun NoaJitModule.rmsOptimiser(learningRate: Double): RMSpropOptimiser =
-        RMSpropOptimiser(scope, JNoa.rmsOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements the RMSprop algorithm. Receives the `learning rate`, `alpha` (smoothing constant),
+     * `eps` (term added to the denominator to improve numerical stability), `weight_decay`,
+     * the `momentum` factor, and `centered` (if true, compute the centered RMSprop).
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
+     *
+     * @receiver the [NoaJitModule] whose parameters are optimised.
+     * @return RMSpropOptimiser.
+     */
+    public fun NoaJitModule.rmsOptimiser(learningRate: Double, alpha: Double,
+        eps: Double, weightDecay: Double, momentum: Double, centered: Boolean): RMSpropOptimiser =
+        RMSpropOptimiser(scope, JNoa.rmsOptim(jitModuleHandle, learningRate, alpha,
+            eps, weightDecay, momentum, centered))

-    public fun NoaJitModule.adamWOptimiser(learningRate: Double): AdamWOptimiser =
-        AdamWOptimiser(scope, JNoa.adamWOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements the AdamW algorithm. Receives the `learning rate`, `beta1` and `beta2` (coefficients
+     * used for computing running averages of the gradient and its square), `eps` (term added to the
+     * denominator to improve numerical stability), `weight_decay`, and `amsgrad`.
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+     *
+     * @receiver the [NoaJitModule] whose parameters are optimised.
+     * @return AdamWOptimiser.
+     */
+    public fun NoaJitModule.adamWOptimiser(learningRate: Double, beta1: Double,
+        beta2: Double, eps: Double, weightDecay: Double, amsgrad: Boolean): AdamWOptimiser =
+        AdamWOptimiser(scope, JNoa.adamWOptim(jitModuleHandle, learningRate, beta1,
+            beta2, eps, weightDecay, amsgrad))

-    public fun NoaJitModule.adagradOptimiser(learningRate: Double): AdagradOptimiser =
-        AdagradOptimiser(scope, JNoa.adagradOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements the Adagrad algorithm. Receives the `learning rate`, `weight_decay`,
+     * the `learning rate decay`, the `initial accumulator value`, and `eps`.
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html
+     *
+     * @receiver the [NoaJitModule] whose parameters are optimised.
+     * @return AdagradOptimiser.
+     */
+    public fun NoaJitModule.adagradOptimiser(learningRate: Double, weightDecay: Double,
+        lrDecay: Double, initialAccumulatorValue: Double, eps: Double): AdagradOptimiser =
+        AdagradOptimiser(scope, JNoa.adagradOptim(jitModuleHandle, learningRate, weightDecay,
+            lrDecay, initialAccumulatorValue, eps))

-    public fun NoaJitModule.sgdOptimiser(learningRate: Double): SgdOptimiser =
-        SgdOptimiser(scope, JNoa.sgdOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements stochastic gradient descent. Receives the `learning rate`, the `momentum` factor,
+     * `dampening` for momentum, `weight_decay`, and `nesterov` (enables Nesterov momentum).
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
+     *
+     * @receiver the [NoaJitModule] whose parameters are optimised.
+     * @return SgdOptimiser.
+     */
+    public fun NoaJitModule.sgdOptimiser(learningRate: Double, momentum: Double,
+        dampening: Double, weightDecay: Double, nesterov: Boolean): SgdOptimiser =
+        SgdOptimiser(scope, JNoa.sgdOptim(jitModuleHandle, learningRate, momentum,
+            dampening, weightDecay, nesterov))
 }

 public sealed class NoaDoubleAlgebra

@@ -734,4 +776,3 @@ protected constructor(scope: NoaScope) :
     override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit =
         JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array)
 }
-
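
Because the Kotlin wrappers take every option explicitly and declare no default arguments, callers that used to pass only a learning rate now have to spell out all of them. A minimal usage sketch, assuming a NoaJitModule called netModule has already been loaded inside the appropriate Noa algebra scope; the learning rates are arbitrary, while the remaining arguments are PyTorch's documented defaults for each optimizer:

// Argument order follows the new signatures; values other than the learning rate are
// PyTorch's documented defaults.
// rmsOptimiser(learningRate, alpha, eps, weightDecay, momentum, centered)
val rms = netModule.rmsOptimiser(0.01, 0.99, 1e-8, 0.0, 0.0, false)
// adamWOptimiser(learningRate, beta1, beta2, eps, weightDecay, amsgrad)
val adamW = netModule.adamWOptimiser(0.001, 0.9, 0.999, 1e-8, 0.01, false)
// adagradOptimiser(learningRate, weightDecay, lrDecay, initialAccumulatorValue, eps)
val adagrad = netModule.adagradOptimiser(0.01, 0.0, 0.0, 0.0, 1e-10)
// sgdOptimiser(learningRate, momentum, dampening, weightDecay, nesterov)
val sgd = netModule.sgdOptimiser(0.01, 0.0, 0.0, 0.0, false)

Adding default values to the Kotlin parameters would have preserved the old single-argument calls; as written, the commit keeps the wrappers a thin one-to-one mapping onto the JNI layer, and the tests below are updated accordingly.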


@@ -71,7 +71,7 @@ class TestJitModules {
         val yPred = netModule.forward(xTrain)
         val loss = lossModule.forward(yPred)

-        val optimiser = netModule.rmsOptimiser(0.005)
+        val optimiser = netModule.rmsOptimiser(0.005, 0.99, 1e-08, 0.0, 0.0, false)

         repeat(250){
             optimiser.zeroGrad()
@@ -107,7 +107,7 @@ class TestJitModules {
         val yPred = netModule.forward(xTrain)
         val loss = lossModule.forward(yPred)

-        val optimiser = netModule.adamWOptimiser(0.005)
+        val optimiser = netModule.adamWOptimiser(0.005, 0.9, 0.999, 1e-08, 0.01, false)

         repeat(250){
             optimiser.zeroGrad()
@@ -143,7 +143,7 @@ class TestJitModules {
         val yPred = netModule.forward(xTrain)
         val loss = lossModule.forward(yPred)

-        val optimiser = netModule.adagradOptimiser(0.005)
+        val optimiser = netModule.adagradOptimiser(0.05, 0.0, 0.0, 0.0, 1e-10)

         repeat(250){
             optimiser.zeroGrad()
@@ -179,9 +179,9 @@ class TestJitModules {
         val yPred = netModule.forward(xTrain)
         val loss = lossModule.forward(yPred)

-        val optimiser = netModule.sgdOptimiser(0.005)
+        val optimiser = netModule.sgdOptimiser(0.01, 0.9, 0.0, 0.0, false)

-        repeat(250){
+        repeat(400){
             optimiser.zeroGrad()
             netModule.forwardAssign(xTrain, yPred)
             lossModule.forwardAssign(yPred, loss)
@@ -195,4 +195,4 @@ class TestJitModules {
         assertTrue(loss.value() < 0.1)

     }!!
 }
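
Put together, the updated SGD test reduces to the loop sketched below. Only the zeroGrad and forwardAssign calls are visible in this diff; the backward() and step() calls are assumed wrapper methods added here for completeness, not lines taken from the test.

// Reconstructed sketch of the SGD training loop from the test above; see the caveat in
// the preceding paragraph about the assumed backward()/step() calls.
val optimiser = netModule.sgdOptimiser(0.01, 0.9, 0.0, 0.0, false)
repeat(400) {
    optimiser.zeroGrad()                    // clear accumulated gradients
    netModule.forwardAssign(xTrain, yPred)  // forward pass into the preallocated output tensor
    lossModule.forwardAssign(yPred, loss)   // evaluate the loss module
    loss.backward()                         // assumed: back-propagate through the jit module
    optimiser.step()                        // assumed: apply one SGD update
}
assertTrue(loss.value() < 0.1)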