From b288c4ce590fcccb806bd82dae9502a2206486f3 Mon Sep 17 00:00:00 2001 From: Anastasia Golovina Date: Mon, 25 Apr 2022 22:17:07 +0300 Subject: [PATCH] add adamw, rms, adagrad, sgd optimizers --- .../java/space/kscience/kmath/noa/JNoa.java | 32 ++++ .../space/kscience/kmath/noa/algebras.kt | 13 ++ .../kotlin/space/kscience/kmath/noa/optim.kt | 35 ++++- .../kscience/kmath/noa/TestJitModules.kt | 147 +++++++++++++++++- 4 files changed, 224 insertions(+), 3 deletions(-) diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index 5f1d56d34..c5ca892d2 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -315,6 +315,38 @@ class JNoa { public static native void stepAdamOptim(long adamOptHandle); public static native void zeroGradAdamOptim(long adamOptHandle); + + public static native long rmsOptim(long jitModuleHandle, double learningRate); + + public static native void disposeRmsOptim(long rmsOptHandle); + + public static native void stepRmsOptim(long rmsOptHandle); + + public static native void zeroGradRmsOptim(long rmsOptHandle); + + public static native long adamWOptim(long jitModuleHandle, double learningRate); + + public static native void disposeAdamWOptim(long adamWOptHandle); + + public static native void stepAdamWOptim(long adamWOptHandle); + + public static native void zeroGradAdamWOptim(long adamWOptHandle); + + public static native long adagradOptim(long jitModuleHandle, double learningRate); + + public static native void disposeAdagradOptim(long adagradOptHandle); + + public static native void stepAdagradOptim(long adagradOptHandle); + + public static native void zeroGradAdagradOptim(long adagradOptHandle); + + public static native long sgdOptim(long jitModuleHandle, double learningRate); + + public static native void disposeSgdOptim(long sgdOptHandle); + + public static native void stepSgdOptim(long sgdOptHandle); + + public static native void zeroGradSgdOptim(long sgdOptHandle); public static native void swapTensors(long lhsHandle, long rhsHandle); diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index edf380f24..f9113acf7 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -342,6 +342,18 @@ protected constructor(scope: NoaScope) : public fun NoaJitModule.adamOptimiser(learningRate: Double): AdamOptimiser = AdamOptimiser(scope, JNoa.adamOptim(jitModuleHandle, learningRate)) + + public fun NoaJitModule.rmsOptimiser(learningRate: Double): RMSpropOptimiser = + RMSpropOptimiser(scope, JNoa.rmsOptim(jitModuleHandle, learningRate)) + + public fun NoaJitModule.adamWOptimiser(learningRate: Double): AdamWOptimiser = + AdamWOptimiser(scope, JNoa.adamWOptim(jitModuleHandle, learningRate)) + + public fun NoaJitModule.adagradOptimiser(learningRate: Double): AdagradOptimiser = + AdagradOptimiser(scope, JNoa.adagradOptim(jitModuleHandle, learningRate)) + + public fun NoaJitModule.sgdOptimiser(learningRate: Double): SgdOptimiser = + SgdOptimiser(scope, JNoa.sgdOptim(jitModuleHandle, learningRate)) } public sealed class NoaDoubleAlgebra @@ -722,3 +734,4 @@ protected constructor(scope: NoaScope) : override fun NoaIntTensor.set(dim: Int, slice: Slice, array: IntArray): Unit = JNoa.setSliceBlobInt(tensorHandle, dim, slice.first, slice.second, array) } + diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt index bf91968dc..ca29d5fe3 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/optim.kt @@ -22,4 +22,37 @@ internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHan override fun dispose(): Unit = JNoa.disposeAdamOptim(optimiserHandle) override fun step(): Unit = JNoa.stepAdamOptim(optimiserHandle) override fun zeroGrad(): Unit = JNoa.zeroGradAdamOptim(optimiserHandle) -} \ No newline at end of file +} + +public class RMSpropOptimiser +internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle) + : NoaOptimiser(scope) { + override fun dispose(): Unit = JNoa.disposeRmsOptim(optimiserHandle) + override fun step(): Unit = JNoa.stepRmsOptim(optimiserHandle) + override fun zeroGrad(): Unit = JNoa.zeroGradRmsOptim(optimiserHandle) +} + + +public class AdamWOptimiser +internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle) + : NoaOptimiser(scope) { + override fun dispose(): Unit = JNoa.disposeAdamWOptim(optimiserHandle) + override fun step(): Unit = JNoa.stepAdamWOptim(optimiserHandle) + override fun zeroGrad(): Unit = JNoa.zeroGradAdamWOptim(optimiserHandle) +} + +public class AdagradOptimiser +internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle) + : NoaOptimiser(scope) { + override fun dispose(): Unit = JNoa.disposeAdagradOptim(optimiserHandle) + override fun step(): Unit = JNoa.stepAdagradOptim(optimiserHandle) + override fun zeroGrad(): Unit = JNoa.zeroGradAdagradOptim(optimiserHandle) +} + +public class SgdOptimiser +internal constructor(scope: NoaScope, internal val optimiserHandle: OptimiserHandle) + : NoaOptimiser(scope) { + override fun dispose(): Unit = JNoa.disposeSgdOptim(optimiserHandle) + override fun step(): Unit = JNoa.stepSgdOptim(optimiserHandle) + override fun zeroGrad(): Unit = JNoa.zeroGradSgdOptim(optimiserHandle) +} diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt index 0d6fbdbdc..2f4c81847 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt @@ -17,7 +17,7 @@ class TestJitModules { private val lossPath = resources.resolve("loss.pt").absolutePath @Test - fun testOptimisation() = NoaFloat { + fun testOptimisationAdam() = NoaFloat { setSeed(SEED) @@ -52,4 +52,147 @@ class TestJitModules { assertTrue(loss.value() < 0.1) }!! -} \ No newline at end of file + @Test + fun testOptimisationRms() = NoaFloat { + + setSeed(SEED) + + val dataModule = loadJitModule(dataPath) + val netModule = loadJitModule(netPath) + val lossModule = loadJitModule(lossPath) + + val xTrain = dataModule.getBuffer("x_train") + val yTrain = dataModule.getBuffer("y_train") + val xVal = dataModule.getBuffer("x_val") + val yVal = dataModule.getBuffer("y_val") + + netModule.train(true) + lossModule.setBuffer("target", yTrain) + + val yPred = netModule.forward(xTrain) + val loss = lossModule.forward(yPred) + val optimiser = netModule.rmsOptimiser(0.005) + + repeat(250){ + optimiser.zeroGrad() + netModule.forwardAssign(xTrain, yPred) + lossModule.forwardAssign(yPred, loss) + loss.backward() + optimiser.step() + } + + netModule.forwardAssign(xVal, yPred) + lossModule.setBuffer("target", yVal) + lossModule.forwardAssign(yPred, loss) + + assertTrue(loss.value() < 0.1) + }!! + + @Test + fun testOptimisationAdamW() = NoaFloat { + + setSeed(SEED) + + val dataModule = loadJitModule(dataPath) + val netModule = loadJitModule(netPath) + val lossModule = loadJitModule(lossPath) + + val xTrain = dataModule.getBuffer("x_train") + val yTrain = dataModule.getBuffer("y_train") + val xVal = dataModule.getBuffer("x_val") + val yVal = dataModule.getBuffer("y_val") + + netModule.train(true) + lossModule.setBuffer("target", yTrain) + + val yPred = netModule.forward(xTrain) + val loss = lossModule.forward(yPred) + val optimiser = netModule.adamWOptimiser(0.005) + + repeat(250){ + optimiser.zeroGrad() + netModule.forwardAssign(xTrain, yPred) + lossModule.forwardAssign(yPred, loss) + loss.backward() + optimiser.step() + } + + netModule.forwardAssign(xVal, yPred) + lossModule.setBuffer("target", yVal) + lossModule.forwardAssign(yPred, loss) + + assertTrue(loss.value() < 0.1) + }!! + + @Test + fun testOptimisationAdagrad() = NoaFloat { + + setSeed(SEED) + + val dataModule = loadJitModule(dataPath) + val netModule = loadJitModule(netPath) + val lossModule = loadJitModule(lossPath) + + val xTrain = dataModule.getBuffer("x_train") + val yTrain = dataModule.getBuffer("y_train") + val xVal = dataModule.getBuffer("x_val") + val yVal = dataModule.getBuffer("y_val") + + netModule.train(true) + lossModule.setBuffer("target", yTrain) + + val yPred = netModule.forward(xTrain) + val loss = lossModule.forward(yPred) + val optimiser = netModule.adagradOptimiser(0.005) + + repeat(250){ + optimiser.zeroGrad() + netModule.forwardAssign(xTrain, yPred) + lossModule.forwardAssign(yPred, loss) + loss.backward() + optimiser.step() + } + + netModule.forwardAssign(xVal, yPred) + lossModule.setBuffer("target", yVal) + lossModule.forwardAssign(yPred, loss) + + assertTrue(loss.value() < 0.1) + }!! + + @Test + fun testOptimisationSgd() = NoaFloat { + + setSeed(SEED) + + val dataModule = loadJitModule(dataPath) + val netModule = loadJitModule(netPath) + val lossModule = loadJitModule(lossPath) + + val xTrain = dataModule.getBuffer("x_train") + val yTrain = dataModule.getBuffer("y_train") + val xVal = dataModule.getBuffer("x_val") + val yVal = dataModule.getBuffer("y_val") + + netModule.train(true) + lossModule.setBuffer("target", yTrain) + + val yPred = netModule.forward(xTrain) + val loss = lossModule.forward(yPred) + val optimiser = netModule.sgdOptimiser(0.005) + + repeat(250){ + optimiser.zeroGrad() + netModule.forwardAssign(xTrain, yPred) + lossModule.forwardAssign(yPred, loss) + loss.backward() + optimiser.step() + } + + netModule.forwardAssign(xVal, yPred) + lossModule.setBuffer("target", yVal) + lossModule.forwardAssign(yPred, loss) + + assertTrue(loss.value() < 0.1) + }!! +}