diff --git a/kmath-noa/README.md b/kmath-noa/README.md
index 0774a9800..502169f20 100644
--- a/kmath-noa/README.md
+++ b/kmath-noa/README.md
@@ -45,6 +45,8 @@ To load the native library you will need to add to the VM options:
 
 ## Usage
 
+### Linear Algebra
+
 We implement the tensor algebra interfaces from [kmath-tensors](../kmath-tensors):
 
 ```kotlin
@@ -62,10 +64,9 @@ NoaFloat {
         tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
     }
 ```
-
+### Automatic Differentiation
 The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
 engine is exposed:
-
 ```kotlin
 NoaFloat {
     // Create a quadratic function
@@ -88,11 +89,124 @@ NoaFloat {
     val hessianAtX = expressionAtX.autoHessian(tensorX)
 }
 ```
+### Deep Learning
+You can train any [TorchScript](https://pytorch.org/docs/stable/jit.html) model.
+For example, you can build the following neural network in `python`
+and prepare the training data:
+```python
+import torch
+n_tr = 7
+n_val = 300
+x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
+y_val = torch.sin(x_val)
+x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
+y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1
+
+class Data(torch.nn.Module):
+    def __init__(self):
+        super(Data, self).__init__()
+        self.register_buffer('x_val', x_val)
+        self.register_buffer('y_val', y_val)
+        self.register_buffer('x_train', x_train)
+        self.register_buffer('y_train', y_train)
+
+class Net(torch.nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.l1 = torch.nn.Linear(1, 10, bias = True)
+        self.l2 = torch.nn.Linear(10, 10, bias = True)
+        self.l3 = torch.nn.Linear(10, 1, bias = True)
+
+    def forward(self, x):
+        x = self.l1(x)
+        x = torch.relu(x)
+        x = self.l2(x)
+        x = torch.relu(x)
+        x = self.l3(x)
+        return x
+
+class Loss(torch.nn.Module):
+    def __init__(self, target):
+        super(Loss, self).__init__()
+        self.register_buffer('target', target)
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, x):
+        return self.loss(x, self.target)
+
+# Generate the TorchScript modules and serialise them
+torch.jit.script(Data()).save('data.pt')
+torch.jit.script(Net()).save('net.pt')
+torch.jit.script(Loss(y_train)).save('loss.pt')
+```
+
+You can then load the modules into `kotlin` and run the training:
+```kotlin
+NoaFloat {
+
+    // Load the serialised JIT modules
+    // The training data
+    val dataModule = loadJitModule("data.pt")
+    // The DL model
+    val netModule = loadJitModule("net.pt")
+    // The loss function
+    val lossModule = loadJitModule("loss.pt")
+
+    // Get the tensors from the data module
+    val xTrain = dataModule.getBuffer("x_train")
+    val yTrain = dataModule.getBuffer("y_train")
+    val xVal = dataModule.getBuffer("x_val")
+    val yVal = dataModule.getBuffer("y_val")
+
+    // Set the model in training mode
+    netModule.train(true)
+    // Point the loss function at the training targets
+    lossModule.setBuffer("target", yTrain)
+
+    // Compute the predictions
+    val yPred = netModule.forward(xTrain)
+    // Compute the training error
+    val loss = lossModule.forward(yPred)
+    println(loss)
+
+    // Set up the Adam optimiser with learning rate 0.005
+    val optimiser = netModule.adamOptimiser(0.005)
+
+    // Train for 250 epochs
+    repeat(250) {
+        // Reset accumulated gradients
+        optimiser.zeroGrad()
+        // Use forwardAssign for better memory management
+        netModule.forwardAssign(xTrain, yPred)
+        lossModule.forwardAssign(yPred, loss)
+        // Backward pass
+        loss.backward()
+        // Update model parameters
+        optimiser.step()
+        if (it % 50 == 0)
+            println("Training loss: $loss")
+    }
+
+    // Finally validate the model
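+    // (this network has no dropout or batch-norm layers; for a model that
+    // does, you would presumably switch to evaluation mode first, e.g. netModule.train(false))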
+    // Compute the predictions for the validation features
+    netModule.forwardAssign(xVal, yPred)
+    // Point the loss function at the validation targets
+    lossModule.setBuffer("target", yVal)
+    // Compute the loss on the validation dataset
+    lossModule.forwardAssign(yPred, loss)
+    println("Validation loss: $loss")
+}
+
+```
+
+### Custom memory management
 
 Native memory management relies on scoping
 with [NoaScope](src/main/kotlin/space/kscience/kmath/noa/memory/NoaScope.kt)
-which is readily within an algebra context.
+which is readily available within an algebra context.
 Manual management is also possible:
 
 ```kotlin
 // Create a scope
@@ -107,4 +221,5 @@ val tensor = NoaFloat(scope){
 scope.disposeAll()
 
 // Attempts to use tensor here is undefined behaviour
-```
\ No newline at end of file
+```
+Contributed by [Roland Grinis](https://github.com/grinisrit)
diff --git a/kmath-noa/build.gradle.kts b/kmath-noa/build.gradle.kts
index ce4db4337..85933eb4b 100644
--- a/kmath-noa/build.gradle.kts
+++ b/kmath-noa/build.gradle.kts
@@ -160,8 +160,7 @@ tasks["compileJava"].dependsOn(buildCpp)
 
 tasks {
     withType<Test>{
-        systemProperty("java.library.path", //"$home/devspace/noa/cmake-build-release/kmath")
-            "$cppBuildDir/kmath")
+        systemProperty("java.library.path", "$cppBuildDir/kmath")
     }
 }
 
diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
index 17677c01f..24e3c2b05 100644
--- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
+++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java
@@ -298,7 +298,7 @@ class JNoa {
 
     public static native long forwardPass(long jitModuleHandle, long tensorHandle);
 
-    public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
+    public static native void forwardPassAssign(long jitModuleHandle, long featuresHandle, long predsHandle);
 
     public static native long getModuleParameter(long jitModuleHandle, String name);
 
diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
index 94a2a0192..2d880656b 100644
--- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
+++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt
@@ -141,11 +141,13 @@ protected constructor(protected val scope: NoaScope) :
 
     public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
 
-    public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
-        wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
+    public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
+        wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))
 
-    public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
-        JNoa.forwardPassAssign(jitModuleHandle, parameters.tensorHandle)
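+    // Writes the module output for `features` into the pre-allocated
+    // `predictions` tensor, reusing its memory instead of allocating anew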
+    public fun NoaJitModule.forwardAssign(features: TensorType, predictions: TensorType): Unit =
+        JNoa.forwardPassAssign(jitModuleHandle, features.tensorHandle, predictions.tensorHandle)
 
     public fun NoaJitModule.getParameter(name: String): TensorType =
         wrap(JNoa.getModuleParameter(jitModuleHandle, name))
@@ -154,7 +156,7 @@
         JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)
 
     public fun NoaJitModule.getBuffer(name: String): TensorType =
-        wrap(JNoa.getModuleParameter(jitModuleHandle, name))
+        wrap(JNoa.getModuleBuffer(jitModuleHandle, name))
 
     public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
         JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)
diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h
index fedf394fa..f058e859e 100644
--- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h
+++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h
@@ -1122,10 +1122,10 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass
 /*
  * Class:     space_kscience_kmath_noa_JNoa
  * Method:    forwardPassAssign
- * Signature: (JJ)V
+ * Signature: (JJJ)V
  */
 JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
-  (JNIEnv *, jclass, jlong, jlong);
+  (JNIEnv *, jclass, jlong, jlong, jlong);
 
 /*
  * Class:     space_kscience_kmath_noa_JNoa
diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt
new file mode 100644
index 000000000..0d6fbdbdc
--- /dev/null
+++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestJitModules.kt
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2018-2021 KMath contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package space.kscience.kmath.noa
+
+import java.io.File
+import kotlin.test.Test
+import kotlin.test.assertTrue
+
+class TestJitModules {
+
+    private val resources = File("").resolve("src/test/resources")
+    private val dataPath = resources.resolve("data.pt").absolutePath
+    private val netPath = resources.resolve("net.pt").absolutePath
+    private val lossPath = resources.resolve("loss.pt").absolutePath
+
+    @Test
+    fun testOptimisation() = NoaFloat {
+
+        setSeed(SEED)
+
+        val dataModule = loadJitModule(dataPath)
+        val netModule = loadJitModule(netPath)
+        val lossModule = loadJitModule(lossPath)
+
+        val xTrain = dataModule.getBuffer("x_train")
+        val yTrain = dataModule.getBuffer("y_train")
+        val xVal = dataModule.getBuffer("x_val")
+        val yVal = dataModule.getBuffer("y_val")
+
+        netModule.train(true)
+        lossModule.setBuffer("target", yTrain)
+
+        val yPred = netModule.forward(xTrain)
+        val loss = lossModule.forward(yPred)
+        val optimiser = netModule.adamOptimiser(0.005)
+
+        repeat(250) {
+            optimiser.zeroGrad()
+            netModule.forwardAssign(xTrain, yPred)
+            lossModule.forwardAssign(yPred, loss)
+            loss.backward()
+            optimiser.step()
+        }
+
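+        // Validate on the held-out data (cf. the README deep-learning example)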
+        netModule.forwardAssign(xVal, yPred)
+        lossModule.setBuffer("target", yVal)
+        lossModule.forwardAssign(yPred, loss)
+
+        assertTrue(loss.value() < 0.1)
+    }!!
+
+}
\ No newline at end of file
diff --git a/kmath-noa/src/test/resources/data.pt b/kmath-noa/src/test/resources/data.pt
new file mode 100644
index 000000000..17eee0d28
Binary files /dev/null and b/kmath-noa/src/test/resources/data.pt differ
diff --git a/kmath-noa/src/test/resources/jitscripts.py b/kmath-noa/src/test/resources/jitscripts.py
new file mode 100644
index 000000000..b199deef8
--- /dev/null
+++ b/kmath-noa/src/test/resources/jitscripts.py
@@ -0,0 +1,52 @@
+import torch
+
+torch.manual_seed(987654)
+
+n_tr = 7
+n_val = 300
+
+x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
+y_val = torch.sin(x_val)
+x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
+y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1
+
+
+class Data(torch.nn.Module):
+    def __init__(self):
+        super(Data, self).__init__()
+        self.register_buffer('x_val', x_val)
+        self.register_buffer('y_val', y_val)
+        self.register_buffer('x_train', x_train)
+        self.register_buffer('y_train', y_train)
+
+
+class Net(torch.nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.l1 = torch.nn.Linear(1, 10, bias = True)
+        self.l2 = torch.nn.Linear(10, 10, bias = True)
+        self.l3 = torch.nn.Linear(10, 1, bias = True)
+
+    def forward(self, x):
+        x = self.l1(x)
+        x = torch.relu(x)
+        x = self.l2(x)
+        x = torch.relu(x)
+        x = self.l3(x)
+        return x
+
+class Loss(torch.nn.Module):
+    def __init__(self, target):
+        super(Loss, self).__init__()
+        self.register_buffer('target', target)
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, x):
+        return self.loss(x, self.target)
+
+
+torch.jit.script(Data()).save('data.pt')
+
+torch.jit.script(Net()).save('net.pt')
+
+torch.jit.script(Loss(y_train)).save('loss.pt')
diff --git a/kmath-noa/src/test/resources/loss.pt b/kmath-noa/src/test/resources/loss.pt
new file mode 100644
index 000000000..091dd58e9
Binary files /dev/null and b/kmath-noa/src/test/resources/loss.pt differ
diff --git a/kmath-noa/src/test/resources/net.pt b/kmath-noa/src/test/resources/net.pt
new file mode 100644
index 000000000..f1ec8b1f2
Binary files /dev/null and b/kmath-noa/src/test/resources/net.pt differ