Add AdamW, RMSProp, Adagrad, SGD optimizers #483
@ -316,6 +316,42 @@ class JNoa {
|
||||
|
||||
public static native void zeroGradAdamOptim(long adamOptHandle);
|
||||
|
||||
public static native long rmsOptim(long jitModuleHandle, double learningRate, double alpha,
|
||||
double eps, double weight_decay, double momentum, boolean centered);
|
||||
|
||||
public static native void disposeRmsOptim(long rmsOptHandle);
|
||||
|
||||
public static native void stepRmsOptim(long rmsOptHandle);
|
||||
|
||||
public static native void zeroGradRmsOptim(long rmsOptHandle);
|
||||
|
||||
public static native long adamWOptim(long jitModuleHandle, double learningRate, double beta1,
|
||||
double beta2, double eps, double weight_decay, boolean amsgrad);
|
||||
|
||||
public static native void disposeAdamWOptim(long adamWOptHandle);
|
||||
|
||||
public static native void stepAdamWOptim(long adamWOptHandle);
|
||||
|
||||
public static native void zeroGradAdamWOptim(long adamWOptHandle);
|
||||
|
||||
public static native long adagradOptim(long jitModuleHandle, double learningRate, double weight_decay,
|
||||
double lr_decay, double initial_accumulator_value, double eps);
|
||||
|
||||
public static native void disposeAdagradOptim(long adagradOptHandle);
|
||||
|
||||
public static native void stepAdagradOptim(long adagradOptHandle);
|
||||
|
||||
public static native void zeroGradAdagradOptim(long adagradOptHandle);
|
||||
|
||||
public static native long sgdOptim(long jitModuleHandle, double learningRate, double momentum,
|
||||
double dampening, double weight_decay, boolean nesterov);
|
||||
|
||||
public static native void disposeSgdOptim(long sgdOptHandle);
|
||||
|
||||
public static native void stepSgdOptim(long sgdOptHandle);
|
||||
|
||||
public static native void zeroGradSgdOptim(long sgdOptHandle);
|
||||
|
||||
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||
|
||||
public static native long loadTensorDouble(String path, int device);
|
||||
|
@ -342,6 +342,60 @@ protected constructor(scope: NoaScope) :
|
||||
|
||||
public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
|
||||
AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))
|
||||
|
||||
/**
|
||||
* Implements RMSprop algorithm. Receive `learning rate`, `alpha` (smoothing constant),
|
||||
* `eps` (term added to the denominator to improve numerical stability), `weight_decay`,
|
||||
* `momentum` factor, `centered` (if True, compute the centered RMSProp).
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
|
||||
*
|
||||
* @receiver the `learning rate`, `alpha`, `eps`, `weight_decay`, `momentum`, `centered`.
|
||||
* @return RMSpropOptimiser.
|
||||
*/
|
||||
public fun NoaJitModule.rmsOptimiser(learningRate: Double, alpha: Double,
|
||||
eps: Double, weightDecay: Double, momentum: Double, centered: Boolean): RMSpropOptimiser =
|
||||
RMSpropOptimiser(scope, JNoa.rmsOptim(jitModuleHandle, learningRate, alpha,
|
||||
eps, weightDecay, momentum, centered))
|
||||
|
||||
/**
|
||||
* Implements AdamW algorithm. Receive `learning rate`, `beta1` and `beta2` (coefficients used
|
||||
* for computing running averages of gradient and its square), `eps` (term added to the denominator
|
||||
* to improve numerical stability), `weight_decay`, `amsgrad`.
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
|
||||
*
|
||||
* @receiver the `learning rate`, `beta1`, `beta2`, `eps`, `weight_decay`, `amsgrad`.
|
||||
* @return AdamWOptimiser.
|
||||
*/
|
||||
public fun NoaJitModule.adamWOptimiser(learningRate: Double, beta1: Double,
|
||||
beta2: Double, eps: Double, weightDecay: Double, amsgrad: Boolean): AdamWOptimiser =
|
||||
AdamWOptimiser(scope, JNoa.adamWOptim(jitModuleHandle, learningRate, beta1,
|
||||
beta2, eps, weightDecay, amsgrad))
|
||||
|
||||
/**
|
||||
* Implements Adagrad algorithm. Receive `learning rate`, `weight_decay`,
|
||||
* `learning rate decay`, `initial accumulator value`, `eps`.
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html
|
||||
*
|
||||
* @receiver the `learning rate`, `weight_decay`, `learning rate decay`, `initial accumulator value`, `eps`.
|
||||
* @return AdagradOptimiser.
|
||||
*/
|
||||
public fun NoaJitModule.adagradOptimiser(learningRate: Double, weightDecay: Double,
|
||||
lrDecay: Double, initialAccumulatorValue: Double, eps: Double): AdagradOptimiser =
|
||||
AdagradOptimiser(scope, JNoa.adagradOptim(jitModuleHandle, learningRate, weightDecay,
|
||||
lrDecay, initialAccumulatorValue, eps))
|
||||
|
||||
/**
|
||||
* Implements stochastic gradient descent. Receive `learning rate`, `momentum` factor,
|
||||
* `dampening` for momentum, `weight_decay`, `nesterov` (enables Nesterov momentum).
|
||||
* For more information: https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
|
||||
*
|
||||
* @receiver the `learning rate`, `momentum`, `dampening`, `weight_decay`, `nesterov`.
|
||||
* @return SgdOptimiser.
|
||||
*/
|
||||
public fun NoaJitModule.sgdOptimiser(learningRate: Double, momentum: Double,
|
||||
dampening: Double, weightDecay: Double, nesterov: Boolean): SgdOptimiser =
|
||||
SgdOptimiser(scope, JNoa.sgdOptim(jitModuleHandle, learningRate, momentum,
|
||||
dampening, weightDecay, nesterov))
|
||||
}
|
||||
|
||||
public sealed class NoaDoubleAlgebra
|
||||
|
@ -23,3 +23,36 @@ internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHan
|
||||
override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle)
|
||||
override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle)
|
||||
}
|
||||
|
||||
public class RMSpropOptimiser
|
||||
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||
: NoaOptimiser(scope) {
|
||||
override fun dispose(): Unit = JNoa.disposeRmsOptim(optimiserHandle)
|
||||
override fun step(): Unit = JNoa.stepRmsOptim(optimiserHandle)
|
||||
override fun zeroGrad(): Unit = JNoa.zeroGradRmsOptim(optimiserHandle)
|
||||
}
|
||||
|
||||
|
||||
public class AdamWOptimiser
|
||||
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||
: NoaOptimiser(scope) {
|
||||
override fun dispose(): Unit = JNoa.disposeAdamWOptim(optimiserHandle)
|
||||
override fun step(): Unit = JNoa.stepAdamWOptim(optimiserHandle)
|
||||
override fun zeroGrad(): Unit = JNoa.zeroGradAdamWOptim(optimiserHandle)
|
||||
}
|
||||
|
||||
public class AdagradOptimiser
|
||||
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||
: NoaOptimiser(scope) {
|
||||
override fun dispose(): Unit = JNoa.disposeAdagradOptim(optimiserHandle)
|
||||
override fun step(): Unit = JNoa.stepAdagradOptim(optimiserHandle)
|
||||
override fun zeroGrad(): Unit = JNoa.zeroGradAdagradOptim(optimiserHandle)
|
||||
}
|
||||
|
||||
public class SgdOptimiser
|
||||
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||
: NoaOptimiser(scope) {
|
||||
override fun dispose(): Unit = JNoa.disposeSgdOptim(optimiserHandle)
|
||||
override fun step(): Unit = JNoa.stepSgdOptim(optimiserHandle)
|
||||
override fun zeroGrad(): Unit = JNoa.zeroGradSgdOptim(optimiserHandle)
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ class TestJitModules {
|
||||
private val lossPath = resources.resolve("loss.pt").absolutePath
|
||||
|
||||
@Test
|
||||
fun testOptimisation() = NoaFloat {
|
||||
fun testOptimisationAdam() = NoaFloat {
|
||||
|
||||
setSeed(SEED)
|
||||
|
||||
@ -52,4 +52,147 @@ class TestJitModules {
|
||||
assertTrue(loss.value() < 0.1)
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testOptimisationRms() = NoaFloat {
|
||||
|
||||
setSeed(SEED)
|
||||
|
||||
val dataModule = loadJitModule(dataPath)
|
||||
val netModule = loadJitModule(netPath)
|
||||
val lossModule = loadJitModule(lossPath)
|
||||
|
||||
val xTrain = dataModule.getBuffer("x_train")
|
||||
val yTrain = dataModule.getBuffer("y_train")
|
||||
val xVal = dataModule.getBuffer("x_val")
|
||||
val yVal = dataModule.getBuffer("y_val")
|
||||
|
||||
netModule.train(true)
|
||||
lossModule.setBuffer("target", yTrain)
|
||||
|
||||
val yPred = netModule.forward(xTrain)
|
||||
val loss = lossModule.forward(yPred)
|
||||
val optimiser = netModule.rmsOptimiser(0.005, 0.99, 1e-08, 0.0, 0.0, false)
|
||||
|
||||
repeat(250){
|
||||
optimiser.zeroGrad()
|
||||
netModule.forwardAssign(xTrain, yPred)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
loss.backward()
|
||||
optimiser.step()
|
||||
}
|
||||
|
||||
netModule.forwardAssign(xVal, yPred)
|
||||
lossModule.setBuffer("target", yVal)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
|
||||
assertTrue(loss.value() < 0.1)
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testOptimisationAdamW() = NoaFloat {
|
||||
|
||||
setSeed(SEED)
|
||||
|
||||
val dataModule = loadJitModule(dataPath)
|
||||
val netModule = loadJitModule(netPath)
|
||||
val lossModule = loadJitModule(lossPath)
|
||||
|
||||
val xTrain = dataModule.getBuffer("x_train")
|
||||
val yTrain = dataModule.getBuffer("y_train")
|
||||
val xVal = dataModule.getBuffer("x_val")
|
||||
val yVal = dataModule.getBuffer("y_val")
|
||||
|
||||
netModule.train(true)
|
||||
lossModule.setBuffer("target", yTrain)
|
||||
|
||||
val yPred = netModule.forward(xTrain)
|
||||
val loss = lossModule.forward(yPred)
|
||||
val optimiser = netModule.adamWOptimiser(0.005, 0.9, 0.999, 1e-08, 0.01, false)
|
||||
|
||||
repeat(250){
|
||||
optimiser.zeroGrad()
|
||||
netModule.forwardAssign(xTrain, yPred)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
loss.backward()
|
||||
optimiser.step()
|
||||
}
|
||||
|
||||
netModule.forwardAssign(xVal, yPred)
|
||||
lossModule.setBuffer("target", yVal)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
|
||||
assertTrue(loss.value() < 0.1)
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testOptimisationAdagrad() = NoaFloat {
|
||||
|
||||
setSeed(SEED)
|
||||
|
||||
val dataModule = loadJitModule(dataPath)
|
||||
val netModule = loadJitModule(netPath)
|
||||
val lossModule = loadJitModule(lossPath)
|
||||
|
||||
val xTrain = dataModule.getBuffer("x_train")
|
||||
val yTrain = dataModule.getBuffer("y_train")
|
||||
val xVal = dataModule.getBuffer("x_val")
|
||||
val yVal = dataModule.getBuffer("y_val")
|
||||
|
||||
netModule.train(true)
|
||||
lossModule.setBuffer("target", yTrain)
|
||||
|
||||
val yPred = netModule.forward(xTrain)
|
||||
val loss = lossModule.forward(yPred)
|
||||
val optimiser = netModule.adagradOptimiser(0.05, 0.0, 0.0, 0.0, 1e-10)
|
||||
|
||||
repeat(250){
|
||||
optimiser.zeroGrad()
|
||||
netModule.forwardAssign(xTrain, yPred)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
loss.backward()
|
||||
optimiser.step()
|
||||
}
|
||||
|
||||
netModule.forwardAssign(xVal, yPred)
|
||||
lossModule.setBuffer("target", yVal)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
|
||||
assertTrue(loss.value() < 0.1)
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testOptimisationSgd() = NoaFloat {
|
||||
|
||||
setSeed(SEED)
|
||||
|
||||
val dataModule = loadJitModule(dataPath)
|
||||
val netModule = loadJitModule(netPath)
|
||||
val lossModule = loadJitModule(lossPath)
|
||||
|
||||
val xTrain = dataModule.getBuffer("x_train")
|
||||
val yTrain = dataModule.getBuffer("y_train")
|
||||
val xVal = dataModule.getBuffer("x_val")
|
||||
val yVal = dataModule.getBuffer("y_val")
|
||||
|
||||
netModule.train(true)
|
||||
lossModule.setBuffer("target", yTrain)
|
||||
|
||||
val yPred = netModule.forward(xTrain)
|
||||
val loss = lossModule.forward(yPred)
|
||||
val optimiser = netModule.sgdOptimiser(0.01, 0.9, 0.0, 0.0, false)
|
||||
|
||||
repeat(400){
|
||||
optimiser.zeroGrad()
|
||||
netModule.forwardAssign(xTrain, yPred)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
loss.backward()
|
||||
optimiser.step()
|
||||
}
|
||||
|
||||
netModule.forwardAssign(xVal, yPred)
|
||||
lossModule.setBuffer("target", yVal)
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
|
||||
assertTrue(loss.value() < 0.1)
|
||||
}!!
|
||||
}
|
Loading…
Reference in New Issue
Block a user