forked from kscience/kmath
deep learning example
This commit is contained in:
parent
0bc2c12a05
commit
cccff29378
@ -45,6 +45,8 @@ To load the native library you will need to add to the VM options:
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
###Linear Algebra
|
||||||
|
|
||||||
We implement the tensor algebra interfaces
|
We implement the tensor algebra interfaces
|
||||||
from [kmath-tensors](../kmath-tensors):
|
from [kmath-tensors](../kmath-tensors):
|
||||||
```kotlin
|
```kotlin
|
||||||
@ -62,10 +64,9 @@ NoaFloat {
|
|||||||
tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
|
tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
### Automatic Differentiation
|
||||||
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
|
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
|
||||||
engine is exposed:
|
engine is exposed:
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
NoaFloat {
|
NoaFloat {
|
||||||
// Create a quadratic function
|
// Create a quadratic function
|
||||||
@ -88,11 +89,122 @@ NoaFloat {
|
|||||||
val hessianAtX = expressionAtX.autoHessian(tensorX)
|
val hessianAtX = expressionAtX.autoHessian(tensorX)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
### Deep Learning
|
||||||
|
You can train any [TorchScript](https://pytorch.org/docs/stable/jit.html) model.
|
||||||
|
For example, you can build in `python` the following neural network
|
||||||
|
and prepare the training data:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
|
||||||
|
n_tr = 7
|
||||||
|
n_val = 300
|
||||||
|
x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
|
||||||
|
y_val = torch.sin(x_val)
|
||||||
|
x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
|
||||||
|
y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1
|
||||||
|
|
||||||
|
class Data(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Data, self).__init__()
|
||||||
|
self.register_buffer('x_val', x_val)
|
||||||
|
self.register_buffer('y_val', y_val)
|
||||||
|
self.register_buffer('x_train', x_train)
|
||||||
|
self.register_buffer('y_train', y_train)
|
||||||
|
|
||||||
|
class Net(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.l1 = torch.nn.Linear(1, 10, bias = True)
|
||||||
|
self.l2 = torch.nn.Linear(10, 10, bias = True)
|
||||||
|
self.l3 = torch.nn.Linear(10, 1, bias = True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.l1(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.l2(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.l3(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Loss(torch.nn.Module):
|
||||||
|
def __init__(self, target):
|
||||||
|
super(Loss, self).__init__()
|
||||||
|
self.register_buffer('target', target)
|
||||||
|
self.loss = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.loss(x, self.target)
|
||||||
|
|
||||||
|
# Generated TorchScript modules and serialise
|
||||||
|
torch.jit.script(Data()).save('data.pt')
|
||||||
|
torch.jit.script(Net()).save('net.pt')
|
||||||
|
torch.jit.script(Loss(y_train)).save('loss.pt')
|
||||||
|
```
|
||||||
|
|
||||||
|
You can then load the modules into `kotlin` and train them:
|
||||||
|
```kotlin
|
||||||
|
NoaFloat {
|
||||||
|
|
||||||
|
// Load the serialised JIT modules
|
||||||
|
// The training data
|
||||||
|
val dataModule = loadJitModule("data.pt")
|
||||||
|
// The DL model
|
||||||
|
val netModule = loadJitModule("net.pt")
|
||||||
|
// The loss function
|
||||||
|
val lossModule = loadJitModule("loss.pt")
|
||||||
|
|
||||||
|
// Get the tensors from the module
|
||||||
|
val xTrain = dataModule.getBuffer("x_train")
|
||||||
|
val yTrain = dataModule.getBuffer("y_train")
|
||||||
|
val xVal = dataModule.getBuffer("x_val")
|
||||||
|
val yVal = dataModule.getBuffer("y_val")
|
||||||
|
|
||||||
|
// Set the model in training mode
|
||||||
|
netModule.train(true)
|
||||||
|
// Loss function set for training regime
|
||||||
|
lossModule.setBuffer("target", yTrain)
|
||||||
|
|
||||||
|
// Compute the predictions
|
||||||
|
val yPred = netModule.forward(xTrain)
|
||||||
|
// Compute the training error
|
||||||
|
val loss = lossModule.forward(yPred)
|
||||||
|
println(loss)
|
||||||
|
|
||||||
|
// Set-up the Adam optimiser with learning rate 0.005
|
||||||
|
val optimiser = netModule.adamOptimiser(0.005)
|
||||||
|
|
||||||
|
// Train for 250 epochs
|
||||||
|
repeat(250){
|
||||||
|
// Clean gradients
|
||||||
|
optimiser.zeroGrad()
|
||||||
|
// Use forwardAssign to for better memory management
|
||||||
|
netModule.forwardAssign(xTrain, yPred)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
// Backward pass
|
||||||
|
loss.backward()
|
||||||
|
// Update model parameters
|
||||||
|
optimiser.step()
|
||||||
|
if(it % 50 == 0)
|
||||||
|
println("Training loss: $loss")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally validate the model
|
||||||
|
// Compute the predictions for the validation features
|
||||||
|
netModule.forwardAssign(xVal, yPred)
|
||||||
|
// Set the loss for the true values
|
||||||
|
lossModule.setBuffer("target", yVal)
|
||||||
|
// Compute the loss on validation dataset
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
println("Validation loss: $loss")
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom memory management
|
||||||
Native memory management relies on scoping
|
Native memory management relies on scoping
|
||||||
with [NoaScope](src/main/kotlin/space/kscience/kmath/noa/memory/NoaScope.kt)
|
with [NoaScope](src/main/kotlin/space/kscience/kmath/noa/memory/NoaScope.kt)
|
||||||
which is readily within an algebra context.
|
which is readily available within an algebra context.
|
||||||
Manual management is also possible:
|
Manual management is also possible:
|
||||||
```kotlin
|
```kotlin
|
||||||
// Create a scope
|
// Create a scope
|
||||||
@ -107,4 +219,5 @@ val tensor = NoaFloat(scope){
|
|||||||
scope.disposeAll()
|
scope.disposeAll()
|
||||||
|
|
||||||
// Attempts to use tensor here is undefined behaviour
|
// Attempts to use tensor here is undefined behaviour
|
||||||
```
|
```
|
||||||
|
Contributed by [Roland Grinis](https://github.com/grinisrit)
|
||||||
|
@ -160,8 +160,7 @@ tasks["compileJava"].dependsOn(buildCpp)
|
|||||||
|
|
||||||
tasks {
|
tasks {
|
||||||
withType<Test>{
|
withType<Test>{
|
||||||
systemProperty("java.library.path", //"$home/devspace/noa/cmake-build-release/kmath")
|
systemProperty("java.library.path", "$cppBuildDir/kmath")
|
||||||
"$cppBuildDir/kmath")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,7 +298,7 @@ class JNoa {
|
|||||||
|
|
||||||
public static native long forwardPass(long jitModuleHandle, long tensorHandle);
|
public static native long forwardPass(long jitModuleHandle, long tensorHandle);
|
||||||
|
|
||||||
public static native void forwardPassAssign(long jitModuleHandle, long tensorHandle);
|
public static native void forwardPassAssign(long jitModuleHandle, long featuresHandle, long predsHandle);
|
||||||
|
|
||||||
public static native long getModuleParameter(long jitModuleHandle, String name);
|
public static native long getModuleParameter(long jitModuleHandle, String name);
|
||||||
|
|
||||||
|
@ -141,11 +141,11 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
|
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
|
||||||
|
|
||||||
public fun NoaJitModule.forward(parameters: Tensor<T>): TensorType =
|
public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
|
||||||
wrap(JNoa.forwardPass(jitModuleHandle, parameters.tensor.tensorHandle))
|
wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))
|
||||||
|
|
||||||
public fun NoaJitModule.forwardAssign(parameters: TensorType): Unit =
|
public fun NoaJitModule.forwardAssign(features: TensorType, predictions: TensorType): Unit =
|
||||||
JNoa.forwardPassAssign(jitModuleHandle, parameters.tensorHandle)
|
JNoa.forwardPassAssign(jitModuleHandle, features.tensorHandle, predictions.tensorHandle)
|
||||||
|
|
||||||
public fun NoaJitModule.getParameter(name: String): TensorType =
|
public fun NoaJitModule.getParameter(name: String): TensorType =
|
||||||
wrap(JNoa.getModuleParameter(jitModuleHandle, name))
|
wrap(JNoa.getModuleParameter(jitModuleHandle, name))
|
||||||
@ -154,7 +154,7 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)
|
JNoa.setModuleParameter(jitModuleHandle, name, parameter.tensor.tensorHandle)
|
||||||
|
|
||||||
public fun NoaJitModule.getBuffer(name: String): TensorType =
|
public fun NoaJitModule.getBuffer(name: String): TensorType =
|
||||||
wrap(JNoa.getModuleParameter(jitModuleHandle, name))
|
wrap(JNoa.getModuleBuffer(jitModuleHandle, name))
|
||||||
|
|
||||||
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
|
public fun NoaJitModule.setBuffer(name: String, buffer: Tensor<T>): Unit =
|
||||||
JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)
|
JNoa.setModuleBuffer(jitModuleHandle, name, buffer.tensor.tensorHandle)
|
||||||
|
@ -1122,10 +1122,10 @@ JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPass
|
|||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
* Method: forwardPassAssign
|
* Method: forwardPassAssign
|
||||||
* Signature: (JJ)V
|
* Signature: (JJJ)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_forwardPassAssign
|
||||||
(JNIEnv *, jclass, jlong, jlong);
|
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: space_kscience_kmath_noa_JNoa
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
@ -0,0 +1,55 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.noa
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
class TestJitModules {
|
||||||
|
|
||||||
|
private val resources = File("").resolve("src/test/resources")
|
||||||
|
private val dataPath = resources.resolve("data.pt").absolutePath
|
||||||
|
private val netPath = resources.resolve("net.pt").absolutePath
|
||||||
|
private val lossPath = resources.resolve("loss.pt").absolutePath
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testOptimisation() = NoaFloat {
|
||||||
|
|
||||||
|
setSeed(SEED)
|
||||||
|
|
||||||
|
val dataModule = loadJitModule(dataPath)
|
||||||
|
val netModule = loadJitModule(netPath)
|
||||||
|
val lossModule = loadJitModule(lossPath)
|
||||||
|
|
||||||
|
val xTrain = dataModule.getBuffer("x_train")
|
||||||
|
val yTrain = dataModule.getBuffer("y_train")
|
||||||
|
val xVal = dataModule.getBuffer("x_val")
|
||||||
|
val yVal = dataModule.getBuffer("y_val")
|
||||||
|
|
||||||
|
netModule.train(true)
|
||||||
|
lossModule.setBuffer("target", yTrain)
|
||||||
|
|
||||||
|
val yPred = netModule.forward(xTrain)
|
||||||
|
val loss = lossModule.forward(yPred)
|
||||||
|
val optimiser = netModule.adamOptimiser(0.005)
|
||||||
|
|
||||||
|
repeat(250){
|
||||||
|
optimiser.zeroGrad()
|
||||||
|
netModule.forwardAssign(xTrain, yPred)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
loss.backward()
|
||||||
|
optimiser.step()
|
||||||
|
}
|
||||||
|
|
||||||
|
netModule.forwardAssign(xVal, yPred)
|
||||||
|
lossModule.setBuffer("target", yVal)
|
||||||
|
lossModule.forwardAssign(yPred, loss)
|
||||||
|
|
||||||
|
assertTrue(loss.value() < 0.1)
|
||||||
|
}!!
|
||||||
|
|
||||||
|
}
|
BIN
kmath-noa/src/test/resources/data.pt
Normal file
BIN
kmath-noa/src/test/resources/data.pt
Normal file
Binary file not shown.
52
kmath-noa/src/test/resources/jitscripts.py
Normal file
52
kmath-noa/src/test/resources/jitscripts.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
torch.manual_seed(987654)
|
||||||
|
|
||||||
|
n_tr = 7
|
||||||
|
n_val = 300
|
||||||
|
|
||||||
|
x_val = torch.linspace(-5, 5, n_val).view(-1, 1)
|
||||||
|
y_val = torch.sin(x_val)
|
||||||
|
x_train = torch.linspace(-3.14, 3.14, n_tr).view(-1, 1)
|
||||||
|
y_train = torch.sin(x_train) + torch.randn_like(x_train) * 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class Data(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Data, self).__init__()
|
||||||
|
self.register_buffer('x_val', x_val)
|
||||||
|
self.register_buffer('y_val', y_val)
|
||||||
|
self.register_buffer('x_train', x_train)
|
||||||
|
self.register_buffer('y_train', y_train)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.l1 = torch.nn.Linear(1, 10, bias = True)
|
||||||
|
self.l2 = torch.nn.Linear(10, 10, bias = True)
|
||||||
|
self.l3 = torch.nn.Linear(10, 1, bias = True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.l1(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.l2(x)
|
||||||
|
x = torch.relu(x)
|
||||||
|
x = self.l3(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Loss(torch.nn.Module):
|
||||||
|
def __init__(self, target):
|
||||||
|
super(Loss, self).__init__()
|
||||||
|
self.register_buffer('target', target)
|
||||||
|
self.loss = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.loss(x, self.target)
|
||||||
|
|
||||||
|
|
||||||
|
torch.jit.script(Data()).save('data.pt')
|
||||||
|
|
||||||
|
torch.jit.script(Net()).save('net.pt')
|
||||||
|
|
||||||
|
torch.jit.script(Loss(y_train)).save('loss.pt')
|
BIN
kmath-noa/src/test/resources/loss.pt
Normal file
BIN
kmath-noa/src/test/resources/loss.pt
Normal file
Binary file not shown.
BIN
kmath-noa/src/test/resources/net.pt
Normal file
BIN
kmath-noa/src/test/resources/net.pt
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user