forked from kscience/kmath
deep learning example
parent 0bc2c12a05
commit cccff29378
@@ -45,6 +45,8 @@ To load the native library you will need to add to the VM options:

## Usage

### Linear Algebra

We implement the tensor algebra interfaces from [kmath-tensors](../kmath-tensors):

```kotlin
@@ -62,10 +64,9 @@ NoaFloat {
    tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
}
```

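To ground the reconstruction line above, here is a sketch of a full SVD round trip; `randNormal` and the destructuring `svd()` call are assumptions about the elided hunk, not confirmed by this diff:

```kotlin
NoaFloat {
    // Assumed constructors: a random 5x5 matrix and its SVD factors
    val tensor = randNormal(shape = intArrayOf(5, 5))
    val (tensorU, tensorS, tensorV) = tensor.svd()
    // Rebuild the original matrix from the factors; it should match
    // tensor up to floating-point error
    val reconstructed =
        tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
}
```
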
### Automatic Differentiation

The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html) engine is exposed:

```kotlin
NoaFloat {
    // Create a quadratic function
@@ -88,11 +89,122 @@ NoaFloat {
    val hessianAtX = expressionAtX.autoHessian(tensorX)
}
```

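The hunk elides the construction of the function itself; a sketch of what it might contain, where `randNormal`, `withGradAt`, and `autoGradient` are assumptions consistent with the `autoHessian` call shown above:

```kotlin
NoaFloat {
    // Assumed API: a quadratic form f(x) = x^T S x / 2 with symmetric S
    val tensorX = randNormal(shape = intArrayOf(3))
    val randFeatures = randNormal(shape = intArrayOf(3, 3))
    val tensorSigma = randFeatures + randFeatures.transpose(0, 1)

    // Differentiable expression evaluated at tensorX
    val expressionAtX = withGradAt(tensorX) { x ->
        0.5f * (x dot (tensorSigma dot x))
    }

    // First and second derivatives at tensorX
    val gradientAtX = expressionAtX.autoGradient(tensorX, retainGraph = true)
    val hessianAtX = expressionAtX.autoHessian(tensorX)
}
```
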
### Deep Learning

You can train any [TorchScript](https://pytorch.org/docs/stable/jit.html) model.
For example, you can build the following neural network in `python`
and prepare the training data:

```python
import torch

n_tr = 7
n_val = 300
x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
y_val = torch.sin(x_val)
x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1


class Data(torch.nn.Module):
    def __init__(self):
        super(Data, self).__init__()
        self.register_buffer('x_val', x_val)
        self.register_buffer('y_val', y_val)
        self.register_buffer('x_train', x_train)
        self.register_buffer('y_train', y_train)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(1, 10, bias=True)
        self.l2 = torch.nn.Linear(10, 10, bias=True)
        self.l3 = torch.nn.Linear(10, 1, bias=True)

    def forward(self, x):
        x = self.l1(x)
        x = torch.relu(x)
        x = self.l2(x)
        x = torch.relu(x)
        x = self.l3(x)
        return x


class Loss(torch.nn.Module):
    def __init__(self, target):
        super(Loss, self).__init__()
        self.register_buffer('target', target)
        self.loss = torch.nn.MSELoss()

    def forward(self, x):
        return self.loss(x, self.target)


# Generate the TorchScript modules and serialise them
torch.jit.script(Data()).save('data.pt')
torch.jit.script(Net()).save('net.pt')
torch.jit.script(Loss(y_train)).save('loss.pt')
```

You can then load the modules into `kotlin` and train them:

```kotlin
NoaFloat {

    // Load the serialised JIT modules:
    // the training data
    val dataModule = loadJitModule("data.pt")
    // the DL model
    val netModule = loadJitModule("net.pt")
    // the loss function
    val lossModule = loadJitModule("loss.pt")

    // Get the tensors from the data module
    val xTrain = dataModule.getBuffer("x_train")
    val yTrain = dataModule.getBuffer("y_train")
    val xVal = dataModule.getBuffer("x_val")
    val yVal = dataModule.getBuffer("y_val")

    // Set the model in training mode
    netModule.train(true)
    // Set the loss function up for the training regime
    lossModule.setBuffer("target", yTrain)

    // Compute the predictions
    val yPred = netModule.forward(xTrain)
    // Compute the training error
    val loss = lossModule.forward(yPred)
    println(loss)

    // Set up the Adam optimiser with learning rate 0.005
    val optimiser = netModule.adamOptimiser(0.005)

    // Train for 250 epochs
    repeat(250) {
        // Clean gradients
        optimiser.zeroGrad()
        // Use forwardAssign for better memory management
        netModule.forwardAssign(xTrain, yPred)
        lossModule.forwardAssign(yPred, loss)
        // Backward pass
        loss.backward()
        // Update the model parameters
        optimiser.step()
        if (it % 50 == 0)
            println("Training loss: $loss")
    }

    // Finally, validate the model:
    // compute the predictions for the validation features
    netModule.forwardAssign(xVal, yPred)
    // set the loss for the true values
    lossModule.setBuffer("target", yVal)
    // and compute the loss on the validation dataset
    lossModule.forwardAssign(yPred, loss)
    println("Validation loss: $loss")
}
```

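The same workflow can target an accelerator; a sketch only, where `Device.CUDA(0)` is an assumption mirroring the `Device.CPU` default of `loadJitModule(path, device)` shown in the API changes further down:

```kotlin
NoaFloat {
    // Assumption: a CUDA device constructor analogous to Device.CPU
    val netModule = loadJitModule("net.pt", Device.CUDA(0))
}
```
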
### Custom memory management

Native memory management relies on scoping
with [NoaScope](src/main/kotlin/space/kscience/kmath/noa/memory/NoaScope.kt),
which is readily available within an algebra context.
Manual management is also possible:

```kotlin
// Create a scope
@@ -107,4 +219,5 @@ val tensor = NoaFloat(scope){
scope.disposeAll()

// Any attempt to use tensor here is undefined behaviour
```

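Assembled end to end, the manual workflow reads roughly as follows; the tensor construction inside the block is an assumption standing in for the elided lines above:

```kotlin
// A sketch: randNormal inside the block is an assumed constructor
val scope = NoaScope()

val tensor = NoaFloat(scope) {
    randNormal(shape = intArrayOf(3, 3))
}!!

// Free every native handle registered with the scope
scope.disposeAll()

// tensor now points at freed native memory and must not be used
```
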
Contributed by [Roland Grinis](https://github.com/grinisrit)
@@ -160,8 +160,7 @@ tasks["compileJava"].dependsOn(buildCpp)

tasks {
    withType<Test> {
        systemProperty("java.library.path", "$cppBuildDir/kmath")
    }
}

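The same property would also be needed when running an application against the native library; a sketch, assuming a standard Gradle `JavaExec` task (not part of this commit):

```kotlin
tasks.withType<JavaExec> {
    // Point the JVM at the directory containing the built native library
    systemProperty("java.library.path", "$cppBuildDir/kmath")
}
```
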
@@ -298,7 +298,7 @@ class JNoa {

    public static native long forwardPass(long jitModuleHandle, long tensorHandle);

    public static native void forwardPassAssign(long jitModuleHandle, long featuresHandle, long predsHandle);

    public static native long getModuleParameter(long jitModuleHandle, String name);

@@ -141,11 +141,11 @@ protected constructor(protected val scope: NoaScope) :

    public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule

    public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
        wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))

    public fun NoaJitModule.forwardAssign(features: TensorType, predictions: TensorType): Unit =
        JNoa.forwardPassAssign(jitModuleHandle, features.tensorHandle, predictions.tensorHandle)

    public fun NoaJitModule.getParameter(name: String): TensorType =
        wrap(JNoa.getModuleParameter(jitModuleHandle, name))
@@ -154,7 +154,7 @@ protected constructor(protected val scope: NoaScope) :
        JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)

    public fun NoaJitModule.getBuffer(name: String): TensorType =
        wrap(JNoa.getModuleBuffer(jitModuleHandle, name))

    public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
        JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)

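As a usage illustration of these wrappers (a sketch; the parameter name follows TorchScript's convention for the `Net` module defined earlier and is not confirmed by this diff):

```kotlin
NoaFloat {
    val netModule = loadJitModule("net.pt")
    // Read the weights of the first linear layer by name (assumed name)
    val w1 = netModule.getParameter("l1.weight")
    println(w1)
}
```
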
@@ -1122,10 +1122,10 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass

/*
 * Class:     space_kscience_kmath_noa_JNoa
 * Method:    forwardPassAssign
 * Signature: (JJJ)V
 */
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
        (JNIEnv *, jclass, jlong, jlong, jlong);

/*
 * Class:     space_kscience_kmath_noa_JNoa

@@ -0,0 +1,55 @@
/*
 * Copyright 2018-2021 KMath contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package space.kscience.kmath.noa

import java.io.File
import kotlin.test.Test
import kotlin.test.assertTrue

class TestJitModules {

    private val resources = File("").resolve("src/test/resources")
    private val dataPath = resources.resolve("data.pt").absolutePath
    private val netPath = resources.resolve("net.pt").absolutePath
    private val lossPath = resources.resolve("loss.pt").absolutePath

    @Test
    fun testOptimisation() = NoaFloat {

        setSeed(SEED)

        val dataModule = loadJitModule(dataPath)
        val netModule = loadJitModule(netPath)
        val lossModule = loadJitModule(lossPath)

        val xTrain = dataModule.getBuffer("x_train")
        val yTrain = dataModule.getBuffer("y_train")
        val xVal = dataModule.getBuffer("x_val")
        val yVal = dataModule.getBuffer("y_val")

        netModule.train(true)
        lossModule.setBuffer("target", yTrain)

        val yPred = netModule.forward(xTrain)
        val loss = lossModule.forward(yPred)
        val optimiser = netModule.adamOptimiser(0.005)

        repeat(250) {
            optimiser.zeroGrad()
            netModule.forwardAssign(xTrain, yPred)
            lossModule.forwardAssign(yPred, loss)
            loss.backward()
            optimiser.step()
        }

        netModule.forwardAssign(xVal, yPred)
        lossModule.setBuffer("target", yVal)
        lossModule.forwardAssign(yPred, loss)

        assertTrue(loss.value() < 0.1)
    }!!

}
BIN kmath-noa/src/test/resources/data.pt (new file; binary file not shown)

52 kmath-noa/src/test/resources/jitscripts.py (new file)
@@ -0,0 +1,52 @@
import torch

torch.manual_seed(987654)

n_tr = 7
n_val = 300

x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
y_val = torch.sin(x_val)
x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1


class Data(torch.nn.Module):
    def __init__(self):
        super(Data, self).__init__()
        self.register_buffer('x_val', x_val)
        self.register_buffer('y_val', y_val)
        self.register_buffer('x_train', x_train)
        self.register_buffer('y_train', y_train)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(1, 10, bias=True)
        self.l2 = torch.nn.Linear(10, 10, bias=True)
        self.l3 = torch.nn.Linear(10, 1, bias=True)

    def forward(self, x):
        x = self.l1(x)
        x = torch.relu(x)
        x = self.l2(x)
        x = torch.relu(x)
        x = self.l3(x)
        return x


class Loss(torch.nn.Module):
    def __init__(self, target):
        super(Loss, self).__init__()
        self.register_buffer('target', target)
        self.loss = torch.nn.MSELoss()

    def forward(self, x):
        return self.loss(x, self.target)


torch.jit.script(Data()).save('data.pt')

torch.jit.script(Net()).save('net.pt')

torch.jit.script(Loss(y_train)).save('loss.pt')
BIN kmath-noa/src/test/resources/loss.pt (new file; binary file not shown)

BIN kmath-noa/src/test/resources/net.pt (new file; binary file not shown)