From 3c92bfda59a4e7774a869ac01423220b7e817136 Mon Sep 17 00:00:00 2001
From: Anastasia Golovina
Date: Mon, 9 May 2022 18:22:05 +0300
Subject: [PATCH] add options for optimizers

---
 .../java/space/kscience/kmath/noa/JNoa.java  | 14 +++--
 .../space/kscience/kmath/noa/algebras.kt     | 59 ++++++++++++++++---
 .../kscience/kmath/noa/TestJitModules.kt     | 12 ++--
 3 files changed, 65 insertions(+), 20 deletions(-)

diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
index c5ca892d2..e6d67b0b9 100644
--- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
+++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
@@ -315,8 +315,9 @@ class JNoa {
     public static native void stepAdamOptim(long adamOptHandle);
 
     public static native void zeroGradAdamOptim(long adamOptHandle);
-
-    public static native long rmsOptim(long jitModuleHandle, double learningRate);
+
+    public static native long rmsOptim(long jitModuleHandle, double learningRate, double alpha,
+                                       double eps, double weight_decay, double momentum, boolean centered);
 
     public static native void disposeRmsOptim(long rmsOptHandle);
 
@@ -324,7 +325,8 @@ class JNoa {
 
     public static native void zeroGradRmsOptim(long rmsOptHandle);
 
-    public static native long adamWOptim(long jitModuleHandle, double learningRate);
+    public static native long adamWOptim(long jitModuleHandle, double learningRate, double beta1,
+                                         double beta2, double eps, double weight_decay, boolean amsgrad);
 
     public static native void disposeAdamWOptim(long adamWOptHandle);
 
@@ -332,7 +334,8 @@ class JNoa {
 
     public static native void zeroGradAdamWOptim(long adamWOptHandle);
 
-    public static native long adagradOptim(long jitModuleHandle, double learningRate);
+    public static native long adagradOptim(long jitModuleHandle, double learningRate, double weight_decay,
+                                           double lr_decay, double initial_accumulator_value, double eps);
 
     public static native void disposeAdagradOptim(long adagradOptHandle);
 
@@ -340,7 +343,8 @@ class JNoa {
 
     public static native void zeroGradAdagradOptim(long adagradOptHandle);
 
-    public static native long sgdOptim(long jitModuleHandle, double learningRate);
+    public static native long sgdOptim(long jitModuleHandle, double learningRate, double momentum,
+                                       double dampening, double weight_decay, boolean nesterov);
 
     public static native void disposeSgdOptim(long sgdOptHandle);
 
diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
index f9113acf7..5989780a7 100644
--- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
+++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
@@ -343,17 +343,59 @@ protected constructor(scope: NoaScope) :
     public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
         AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))
 
-    public fun NoaJitModule.rmsOptimiser(learningRate: Double): RMSpropOptimiser =
-        RMSpropOptimiser(scope, JNoa.rmsOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements the RMSprop algorithm. Receives the `learning rate`, `alpha` (smoothing constant),
+     * `eps` (term added to the denominator to improve numerical stability), `weight_decay`,
+     * `momentum` factor, and `centered` (if true, compute the centered variant of RMSprop).
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
+     *
+     * @receiver the [NoaJitModule] whose parameters will be optimised.
+     * @return the resulting [RMSpropOptimiser].
+     */
+    public fun NoaJitModule.rmsOptimiser(learningRate: Double, alpha: Double,
+        eps: Double, weightDecay: Double, momentum: Double, centered: Boolean): RMSpropOptimiser =
+            RMSpropOptimiser(scope, JNoa.rmsOptim(jitModuleHandle, learningRate, alpha,
+                eps, weightDecay, momentum, centered))
 
-    public fun NoaJitModule.adamWOptimiser(learningRate: Double): AdamWOptimiser =
-        AdamWOptimiser(scope, JNoa.adamWOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements the AdamW algorithm. Receives the `learning rate`, `beta1` and `beta2` (coefficients used
+     * for computing running averages of the gradient and its square), `eps` (term added to the denominator
+     * to improve numerical stability), `weight_decay`, and `amsgrad`.
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+     *
+     * @receiver the [NoaJitModule] whose parameters will be optimised.
+     * @return the resulting [AdamWOptimiser].
+     */
+    public fun NoaJitModule.adamWOptimiser(learningRate: Double, beta1: Double,
+        beta2: Double, eps: Double, weightDecay: Double, amsgrad: Boolean): AdamWOptimiser =
+            AdamWOptimiser(scope, JNoa.adamWOptim(jitModuleHandle, learningRate, beta1,
+                beta2, eps, weightDecay, amsgrad))
 
-    public fun NoaJitModule.adagradOptimiser(learningRate: Double): AdagradOptimiser =
-        AdagradOptimiser(scope, JNoa.adagradOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements the Adagrad algorithm. Receives the `learning rate`, `weight_decay`,
+     * `learning rate decay`, `initial accumulator value`, and `eps`.
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html
+     *
+     * @receiver the [NoaJitModule] whose parameters will be optimised.
+     * @return the resulting [AdagradOptimiser].
+     */
+    public fun NoaJitModule.adagradOptimiser(learningRate: Double, weightDecay: Double,
+        lrDecay: Double, initialAccumulatorValue: Double, eps: Double): AdagradOptimiser =
+            AdagradOptimiser(scope, JNoa.adagradOptim(jitModuleHandle, learningRate, weightDecay,
+                lrDecay, initialAccumulatorValue, eps))
 
-    public fun NoaJitModule.sgdOptimiser(learningRate: Double): SgdOptimiser =
-        SgdOptimiser(scope, JNoa.sgdOptim(jitModuleHandle, learningRate))
+    /**
+     * Implements stochastic gradient descent (optionally with momentum). Receives the `learning rate`,
+     * `momentum` factor, `dampening` for momentum, `weight_decay`, and `nesterov` (enables Nesterov momentum).
+     * For more information see https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
+     *
+     * @receiver the [NoaJitModule] whose parameters will be optimised.
+     * @return the resulting [SgdOptimiser].
+ */ + public fun NoaJitModule.sgdOptimiser(learningRate: Double, momentum: Double, + dampening: Double, weightDecay: Double, nesterov: Boolean): SgdOptimiser = + SgdOptimiser(scope, JNoa.sgdOptim(jitModuleHandle, learningRate, momentum, + dampening, weightDecay, nesterov)) } public sealed class NoaDoubleAlgebra @@ -734,4 +776,3 @@ protected constructor(scope: NoaScope) : override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit = JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array) } - diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt index 2f4c81847..90088977d 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt @@ -71,7 +71,7 @@ class TestJitModules { val yPred = netModule.forward(xTrain) val loss = lossModule.forward(yPred) - val optimiser = netModule.rmsOptimiser(0.005) + val optimiser = netModule.rmsOptimiser(0.005, 0.99, 1e-08, 0.0, 0.0, false) repeat(250){ optimiser.zeroGrad() @@ -107,7 +107,7 @@ class TestJitModules { val yPred = netModule.forward(xTrain) val loss = lossModule.forward(yPred) - val optimiser = netModule.adamWOptimiser(0.005) + val optimiser = netModule.adamWOptimiser(0.005, 0.9, 0.999, 1e-08, 0.01, false) repeat(250){ optimiser.zeroGrad() @@ -143,7 +143,7 @@ class TestJitModules { val yPred = netModule.forward(xTrain) val loss = lossModule.forward(yPred) - val optimiser = netModule.adagradOptimiser(0.005) + val optimiser = netModule.adagradOptimiser(0.05, 0.0, 0.0, 0.0, 1e-10) repeat(250){ optimiser.zeroGrad() @@ -179,9 +179,9 @@ class TestJitModules { val yPred = netModule.forward(xTrain) val loss = lossModule.forward(yPred) - val optimiser = netModule.sgdOptimiser(0.005) + val optimiser = netModule.sgdOptimiser(0.01, 0.9, 0.0, 0.0, false) - repeat(250){ + repeat(400){ optimiser.zeroGrad() netModule.forwardAssign(xTrain, yPred) lossModule.forwardAssign(yPred, loss) @@ -195,4 +195,4 @@ class TestJitModules { assertTrue(loss.value() < 0.1) }!! -} +} \ No newline at end of file