deep learning example

This commit is contained in:
Roland Grinis 2021-07-13 19:48:21 +01:00
parent 0bc2c12a05
commit cccff29378
10 changed files with 233 additions and 14 deletions

View File

@ -45,6 +45,8 @@ To load the native library you will need to add to the VM options:
## Usage ## Usage
###Linear Algebra
We implement the tensor algebra interfaces We implement the tensor algebra interfaces
from [kmath-tensors](../kmath-tensors): from [kmath-tensors](../kmath-tensors):
```kotlin ```kotlin
@ -62,10 +64,9 @@ NoaFloat {
tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)) tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
} }
``` ```
### Automatic Differentiation
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html) The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
engine is exposed: engine is exposed:
```kotlin ```kotlin
NoaFloat { NoaFloat {
// Create a quadratic function // Create a quadratic function
@ -88,11 +89,122 @@ NoaFloat {
val hessianAtX = expressionAtX.autoHessian(tensorX) val hessianAtX = expressionAtX.autoHessian(tensorX)
} }
``` ```
### Deep Learning
You can train any [TorchScript](https://pytorch.org/docs/stable/jit.html) model.
For example, you can build in `python` the following neural network
and prepare the training data:
```python
import torch
n_tr = 7
n_val = 300
x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
y_val = torch.sin(x_val)
x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1
class Data(torch.nn.Module):
def __init__(self):
super(Data, self).__init__()
self.register_buffer('x_val', x_val)
self.register_buffer('y_val', y_val)
self.register_buffer('x_train', x_train)
self.register_buffer('y_train', y_train)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(1, 10, bias = True)
self.l2 = torch.nn.Linear(10, 10, bias = True)
self.l3 = torch.nn.Linear(10, 1, bias = True)
def forward(self, x):
x = self.l1(x)
x = torch.relu(x)
x = self.l2(x)
x = torch.relu(x)
x = self.l3(x)
return x
class Loss(torch.nn.Module):
def __init__(self, target):
super(Loss, self).__init__()
self.register_buffer('target', target)
self.loss = torch.nn.MSELoss()
def forward(self, x):
return self.loss(x, self.target)
# Generated TorchScript modules and serialise
torch.jit.script(Data()).save('data.pt')
torch.jit.script(Net()).save('net.pt')
torch.jit.script(Loss(y_train)).save('loss.pt')
```
You can then load the modules into `kotlin` and train them:
```kotlin
NoaFloat {
// Load the serialised JIT modules
// The training data
val dataModule = loadJitModule("data.pt")
// The DL model
val netModule = loadJitModule("net.pt")
// The loss function
val lossModule = loadJitModule("loss.pt")
// Get the tensors from the module
val xTrain = dataModule.getBuffer("x_train")
val yTrain = dataModule.getBuffer("y_train")
val xVal = dataModule.getBuffer("x_val")
val yVal = dataModule.getBuffer("y_val")
// Set the model in training mode
netModule.train(true)
// Loss function set for training regime
lossModule.setBuffer("target", yTrain)
// Compute the predictions
val yPred = netModule.forward(xTrain)
// Compute the training error
val loss = lossModule.forward(yPred)
println(loss)
// Set-up the Adam optimiser with learning rate 0.005
val optimiser = netModule.adamOptimiser(0.005)
// Train for 250 epochs
repeat(250){
// Clean gradients
optimiser.zeroGrad()
// Use forwardAssign to for better memory management
netModule.forwardAssign(xTrain, yPred)
lossModule.forwardAssign(yPred, loss)
// Backward pass
loss.backward()
// Update model parameters
optimiser.step()
if(it % 50 == 0)
println("Training loss: $loss")
}
// Finally validate the model
// Compute the predictions for the validation features
netModule.forwardAssign(xVal, yPred)
// Set the loss for the true values
lossModule.setBuffer("target", yVal)
// Compute the loss on validation dataset
lossModule.forwardAssign(yPred, loss)
println("Validation loss: $loss")
}
```
### Custom memory management
Native memory management relies on scoping Native memory management relies on scoping
with [NoaScope](src/main/kotlin/space/kscience/kmath/noa/memory/NoaScope.kt) with [NoaScope](src/main/kotlin/space/kscience/kmath/noa/memory/NoaScope.kt)
which is readily within an algebra context. which is readily available within an algebra context.
Manual management is also possible: Manual management is also possible:
```kotlin ```kotlin
// Create a scope // Create a scope
@ -107,4 +219,5 @@ val tensor = NoaFloat(scope){
scope.disposeAll() scope.disposeAll()
// Attempts to use tensor here is undefined behaviour // Attempts to use tensor here is undefined behaviour
``` ```
Contributed by [Roland Grinis](https://github.com/grinisrit)

View File

@ -160,8 +160,7 @@ tasks["compileJava"].dependsOn(buildCpp)
tasks { tasks {
withType<Test>{ withType<Test>{
systemProperty("java.library.path", //"$home/devspace/noa/cmake-build-release/kmath") systemProperty("java.library.path", "$cppBuildDir/kmath")
"$cppBuildDir/kmath")
} }
} }

View File

@ -298,7 +298,7 @@ class JNoa {
public static native long forwardPass(long jitModuleHandle, long tensorHandle); public static native long forwardPass(long jitModuleHandle, long tensorHandle);
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle); public static native void forwardPassAssign(long jitModuleHandle, long featuresHandle, long predsHandle);
public static native long getModuleParameter(long jitModuleHandle, String name); public static native long getModuleParameter(long jitModuleHandle, String name);

View File

@ -141,11 +141,11 @@ protected constructor(protected val scope: NoaScope) :
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType = public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle)) wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit = public fun NoaJitModule.forwardAssign(features: TensorType, predictions: TensorType): Unit =
JNoa.forwardPassAssign(jitModuleHandle, parameters.tensorHandle) JNoa.forwardPassAssign(jitModuleHandle, features.tensorHandle, predictions.tensorHandle)
public fun NoaJitModule.getParameter(name: String): TensorType = public fun NoaJitModule.getParameter(name: String): TensorType =
wrap(JNoa.getModuleParameter(jitModuleHandle, name)) wrap(JNoa.getModuleParameter(jitModuleHandle, name))
@ -154,7 +154,7 @@ protected constructor(protected val scope: NoaScope) :
JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle) JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)
public fun NoaJitModule.getBuffer(name: String): TensorType = public fun NoaJitModule.getBuffer(name: String): TensorType =
wrap(JNoa.getModuleParameter(jitModuleHandle, name)) wrap(JNoa.getModuleBuffer(jitModuleHandle, name))
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit = public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle) JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)

View File

@ -1122,10 +1122,10 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa
* Method: forwardPassAssign * Method: forwardPassAssign
* Signature: (JJ)V * Signature: (JJJ)V
*/ */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
(JNIEnv *, jclass, jlong, jlong); (JNIEnv *, jclass, jlong, jlong, jlong);
/* /*
* Class: space_kscience_kmath_noa_JNoa * Class: space_kscience_kmath_noa_JNoa

View File

@ -0,0 +1,55 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa
import java.io.File
import kotlin.test.Test
import kotlin.test.assertTrue
class TestJitModules {
private val resources = File("").resolve("src/test/resources")
private val dataPath = resources.resolve("data.pt").absolutePath
private val netPath = resources.resolve("net.pt").absolutePath
private val lossPath = resources.resolve("loss.pt").absolutePath
@Test
fun testOptimisation() = NoaFloat {
setSeed(SEED)
val dataModule = loadJitModule(dataPath)
val netModule = loadJitModule(netPath)
val lossModule = loadJitModule(lossPath)
val xTrain = dataModule.getBuffer("x_train")
val yTrain = dataModule.getBuffer("y_train")
val xVal = dataModule.getBuffer("x_val")
val yVal = dataModule.getBuffer("y_val")
netModule.train(true)
lossModule.setBuffer("target", yTrain)
val yPred = netModule.forward(xTrain)
val loss = lossModule.forward(yPred)
val optimiser = netModule.adamOptimiser(0.005)
repeat(250){
optimiser.zeroGrad()
netModule.forwardAssign(xTrain, yPred)
lossModule.forwardAssign(yPred, loss)
loss.backward()
optimiser.step()
}
netModule.forwardAssign(xVal, yPred)
lossModule.setBuffer("target", yVal)
lossModule.forwardAssign(yPred, loss)
assertTrue(loss.value() < 0.1)
}!!
}

Binary file not shown.

View File

@ -0,0 +1,52 @@
import torch
torch.manual_seed(987654)
n_tr = 7
n_val = 300
x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
y_val = torch.sin(x_val)
x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1
class Data(torch.nn.Module):
def __init__(self):
super(Data, self).__init__()
self.register_buffer('x_val', x_val)
self.register_buffer('y_val', y_val)
self.register_buffer('x_train', x_train)
self.register_buffer('y_train', y_train)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(1, 10, bias = True)
self.l2 = torch.nn.Linear(10, 10, bias = True)
self.l3 = torch.nn.Linear(10, 1, bias = True)
def forward(self, x):
x = self.l1(x)
x = torch.relu(x)
x = self.l2(x)
x = torch.relu(x)
x = self.l3(x)
return x
class Loss(torch.nn.Module):
def __init__(self, target):
super(Loss, self).__init__()
self.register_buffer('target', target)
self.loss = torch.nn.MSELoss()
def forward(self, x):
return self.loss(x, self.target)
torch.jit.script(Data()).save('data.pt')
torch.jit.script(Net()).save('net.pt')
torch.jit.script(Loss(y_train)).save('loss.pt')

Binary file not shown.

Binary file not shown.