add adamw, rms, adagrad, sgd optimizers
This commit is contained in:
parent
c53bdd38f8
commit
b288c4ce59
@ -316,6 +316,38 @@ class JNoa {
|
|||||||
|
|
||||||
public static native void zeroGradAdamOptim(long adamOptHandle);
|
public static native void zeroGradAdamOptim(long adamOptHandle);
|
||||||
|
|
||||||
|
public static native long rmsOptim(long jitModuleHandle, double learningRate);
|
||||||
|
|
||||||
|
public static native void disposeRmsOptim(long rmsOptHandle);
|
||||||
|
|
||||||
|
public static native void stepRmsOptim(long rmsOptHandle);
|
||||||
|
|
||||||
|
public static native void zeroGradRmsOptim(long rmsOptHandle);
|
||||||
|
|
||||||
|
public static native long adamWOptim(long jitModuleHandle, double learningRate);
|
||||||
|
|
||||||
|
public static native void disposeAdamWOptim(long adamWOptHandle);
|
||||||
|
|
||||||
|
public static native void stepAdamWOptim(long adamWOptHandle);
|
||||||
|
|
||||||
|
public static native void zeroGradAdamWOptim(long adamWOptHandle);
|
||||||
|
|
||||||
|
public static native long adagradOptim(long jitModuleHandle, double learningRate);
|
||||||
|
|
||||||
|
public static native void disposeAdagradOptim(long adagradOptHandle);
|
||||||
|
|
||||||
|
public static native void stepAdagradOptim(long adagradOptHandle);
|
||||||
|
|
||||||
|
public static native void zeroGradAdagradOptim(long adagradOptHandle);
|
||||||
|
|
||||||
|
public static native long sgdOptim(long jitModuleHandle, double learningRate);
|
||||||
|
|
||||||
|
public static native void disposeSgdOptim(long sgdOptHandle);
|
||||||
|
|
||||||
|
public static native void stepSgdOptim(long sgdOptHandle);
|
||||||
|
|
||||||
|
public static native void zeroGradSgdOptim(long sgdOptHandle);
|
||||||
|
|
||||||
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||||
|
|
||||||
public static native long loadTensorDouble(String path, int device);
|
public static native long loadTensorDouble(String path, int device);
|
||||||
|
@ -342,6 +342,18 @@ protected constructor(scope: NoaScope) :
|
|||||||
|
|
||||||
public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
|
public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser =
|
||||||
AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))
|
AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate))
|
||||||
|
|
||||||
|
public fun NoaJitModule.rmsOptimiser(learningRate: Double): RMSpropOptimiser =
|
||||||
|
RMSpropOptimiser(scope, JNoa.rmsOptim(jitModuleHandle, learningRate))
|
||||||
|
|
||||||
|
public fun NoaJitModule.adamWOptimiser(learningRate: Double): AdamWOptimiser =
|
||||||
|
AdamWOptimiser(scope, JNoa.adamWOptim(jitModuleHandle, learningRate))
|
||||||
|
|
||||||
|
public fun NoaJitModule.adagradOptimiser(learningRate: Double): AdagradOptimiser =
|
||||||
|
AdagradOptimiser(scope, JNoa.adagradOptim(jitModuleHandle, learningRate))
|
||||||
|
|
||||||
|
public fun NoaJitModule.sgdOptimiser(learningRate: Double): SgdOptimiser =
|
||||||
|
SgdOptimiser(scope, JNoa.sgdOptim(jitModuleHandle, learningRate))
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaDoubleAlgebra
|
public sealed class NoaDoubleAlgebra
|
||||||
@ -722,3 +734,4 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit =
|
override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit =
|
||||||
JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array)
|
JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,3 +23,36 @@ internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHan
|
|||||||
override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle)
|
override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle)
|
||||||
override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle)
|
override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class RMSpropOptimiser
|
||||||
|
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||||
|
: NoaOptimiser(scope) {
|
||||||
|
override fun dispose(): Unit = JNoa.disposeRmsOptim(optimiserHandle)
|
||||||
|
override fun step(): Unit = JNoa.stepRmsOptim(optimiserHandle)
|
||||||
|
override fun zeroGrad(): Unit = JNoa.zeroGradRmsOptim(optimiserHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public class AdamWOptimiser
|
||||||
|
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||||
|
: NoaOptimiser(scope) {
|
||||||
|
override fun dispose(): Unit = JNoa.disposeAdamWOptim(optimiserHandle)
|
||||||
|
override fun step(): Unit = JNoa.stepAdamWOptim(optimiserHandle)
|
||||||
|
override fun zeroGrad(): Unit = JNoa.zeroGradAdamWOptim(optimiserHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class AdagradOptimiser
|
||||||
|
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||||
|
: NoaOptimiser(scope) {
|
||||||
|
override fun dispose(): Unit = JNoa.disposeAdagradOptim(optimiserHandle)
|
||||||
|
override fun step(): Unit = JNoa.stepAdagradOptim(optimiserHandle)
|
||||||
|
override fun zeroGrad(): Unit = JNoa.zeroGradAdagradOptim(optimiserHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class SgdOptimiser
|
||||||
|
internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle)
|
||||||
|
: NoaOptimiser(scope) {
|
||||||
|
override fun dispose(): Unit = JNoa.disposeSgdOptim(optimiserHandle)
|
||||||
|
override fun step(): Unit = JNoa.stepSgdOptim(optimiserHandle)
|
||||||
|
override fun zeroGrad(): Unit = JNoa.zeroGradSgdOptim(optimiserHandle)
|
||||||
|
}
|
||||||
|
@ -17,7 +17,7 @@ class TestJitModules {
|
|||||||
private val lossPath = resources.resolve("loss.pt").absolutePath
|
private val lossPath = resources.resolve("loss.pt").absolutePath
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testOptimisation() = NoaFloat {
|
fun testOptimisationAdam() = NoaFloat {
|
||||||
|
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
|
|
||||||
@ -52,4 +52,147 @@ class TestJitModules {
|
|||||||
assertTrue(loss.value() < 0.1)
|
assertTrue(loss.value() < 0.1)
|
||||||
}!!
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testOptimisationRms() = NoaFloat {
|
||||||
|
|
||||||
|
setSeed(SEED)
|
||||||
|
|
||||||
|
val dataModule = loadJitModule(dataPath)
|
||||||
|
val netModule = loadJitModule(netPath)
|
||||||
|
val lossModule = loadJitModule(lossPath)
|
||||||
|
|
||||||
|
val xTrain = dataModule.getBuffer("x_train")
|
||||||
|
val yTrain = dataModule.getBuffer("y_train")
|
||||||
|
val xVal = dataModule.getBuffer("x_val")
|
||||||
|
val yVal = dataModule.getBuffer("y_val")
|
||||||
|
|
||||||
|
netModule.train(true)
|
||||||
|
lossModule.setBuffer("target", yTrain)
|
||||||
|
|
||||||
|
val yPred = netModule.forward(xTrain)
|
||||||
|
val loss = lossModule.forward(yPred)
|
||||||
|
val optimiser = netModule.rmsOptimiser(0.005)
|
||||||
|
|
||||||
|
repeat(250){
|
||||||
|
optimiser.zeroGrad()
|
||||||
|
netModule.forwardAssign(xTrain, yPred)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
loss.backward()
|
||||||
|
optimiser.step()
|
||||||
|
}
|
||||||
|
|
||||||
|
netModule.forwardAssign(xVal, yPred)
|
||||||
|
lossModule.setBuffer("target", yVal)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
|
||||||
|
assertTrue(loss.value() < 0.1)
|
||||||
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testOptimisationAdamW() = NoaFloat {
|
||||||
|
|
||||||
|
setSeed(SEED)
|
||||||
|
|
||||||
|
val dataModule = loadJitModule(dataPath)
|
||||||
|
val netModule = loadJitModule(netPath)
|
||||||
|
val lossModule = loadJitModule(lossPath)
|
||||||
|
|
||||||
|
val xTrain = dataModule.getBuffer("x_train")
|
||||||
|
val yTrain = dataModule.getBuffer("y_train")
|
||||||
|
val xVal = dataModule.getBuffer("x_val")
|
||||||
|
val yVal = dataModule.getBuffer("y_val")
|
||||||
|
|
||||||
|
netModule.train(true)
|
||||||
|
lossModule.setBuffer("target", yTrain)
|
||||||
|
|
||||||
|
val yPred = netModule.forward(xTrain)
|
||||||
|
val loss = lossModule.forward(yPred)
|
||||||
|
val optimiser = netModule.adamWOptimiser(0.005)
|
||||||
|
|
||||||
|
repeat(250){
|
||||||
|
optimiser.zeroGrad()
|
||||||
|
netModule.forwardAssign(xTrain, yPred)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
loss.backward()
|
||||||
|
optimiser.step()
|
||||||
|
}
|
||||||
|
|
||||||
|
netModule.forwardAssign(xVal, yPred)
|
||||||
|
lossModule.setBuffer("target", yVal)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
|
||||||
|
assertTrue(loss.value() < 0.1)
|
||||||
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testOptimisationAdagrad() = NoaFloat {
|
||||||
|
|
||||||
|
setSeed(SEED)
|
||||||
|
|
||||||
|
val dataModule = loadJitModule(dataPath)
|
||||||
|
val netModule = loadJitModule(netPath)
|
||||||
|
val lossModule = loadJitModule(lossPath)
|
||||||
|
|
||||||
|
val xTrain = dataModule.getBuffer("x_train")
|
||||||
|
val yTrain = dataModule.getBuffer("y_train")
|
||||||
|
val xVal = dataModule.getBuffer("x_val")
|
||||||
|
val yVal = dataModule.getBuffer("y_val")
|
||||||
|
|
||||||
|
netModule.train(true)
|
||||||
|
lossModule.setBuffer("target", yTrain)
|
||||||
|
|
||||||
|
val yPred = netModule.forward(xTrain)
|
||||||
|
val loss = lossModule.forward(yPred)
|
||||||
|
val optimiser = netModule.adagradOptimiser(0.005)
|
||||||
|
|
||||||
|
repeat(250){
|
||||||
|
optimiser.zeroGrad()
|
||||||
|
netModule.forwardAssign(xTrain, yPred)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
loss.backward()
|
||||||
|
optimiser.step()
|
||||||
|
}
|
||||||
|
|
||||||
|
netModule.forwardAssign(xVal, yPred)
|
||||||
|
lossModule.setBuffer("target", yVal)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
|
||||||
|
assertTrue(loss.value() < 0.1)
|
||||||
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testOptimisationSgd() = NoaFloat {
|
||||||
|
|
||||||
|
setSeed(SEED)
|
||||||
|
|
||||||
|
val dataModule = loadJitModule(dataPath)
|
||||||
|
val netModule = loadJitModule(netPath)
|
||||||
|
val lossModule = loadJitModule(lossPath)
|
||||||
|
|
||||||
|
val xTrain = dataModule.getBuffer("x_train")
|
||||||
|
val yTrain = dataModule.getBuffer("y_train")
|
||||||
|
val xVal = dataModule.getBuffer("x_val")
|
||||||
|
val yVal = dataModule.getBuffer("y_val")
|
||||||
|
|
||||||
|
netModule.train(true)
|
||||||
|
lossModule.setBuffer("target", yTrain)
|
||||||
|
|
||||||
|
val yPred = netModule.forward(xTrain)
|
||||||
|
val loss = lossModule.forward(yPred)
|
||||||
|
val optimiser = netModule.sgdOptimiser(0.005)
|
||||||
|
|
||||||
|
repeat(250){
|
||||||
|
optimiser.zeroGrad()
|
||||||
|
netModule.forwardAssign(xTrain, yPred)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
loss.backward()
|
||||||
|
optimiser.step()
|
||||||
|
}
|
||||||
|
|
||||||
|
netModule.forwardAssign(xVal, yPred)
|
||||||
|
lossModule.setBuffer("target", yVal)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
|
||||||
|
assertTrue(loss.value() < 0.1)
|
||||||
|
}!!
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user