tensors and modules serialisation

This commit is contained in:
Roland Grinis 2021-07-30 22:33:43 +01:00
parent 686baa6517
commit 36bc127260
9 changed files with 112 additions and 0 deletions

View File

@ -65,8 +65,18 @@ NoaFloat {
// Reconstruct tensor // Reconstruct tensor
val tensorReg = val tensorReg =
tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)) tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
// Serialise tensor for later
tensorReg.save("tensorReg.pt")
} }
``` ```
The saved tensor can be loaded in `C++` or in `python`:
```python
import torch
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
```
### Automatic Differentiation ### Automatic Differentiation
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html) The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
engine is exposed: engine is exposed:
@ -200,6 +210,9 @@ NoaFloat {
// Compute the loss on validation dataset // Compute the loss on validation dataset
lossModule.forwardAssign(yPred, loss) lossModule.forwardAssign(yPred, loss)
println("Validation loss: $loss") println("Validation loss: $loss")
// The model can be serialised in its current state
netModule.save("trained_net.pt")
} }
``` ```

View File

@ -161,6 +161,8 @@ tasks["compileJava"].dependsOn(buildCpp)
tasks { tasks {
withType<Test>{ withType<Test>{
systemProperty("java.library.path", "$cppBuildDir/kmath") systemProperty("java.library.path", "$cppBuildDir/kmath")
//systemProperty("java.library.path",
// "${System.getProperty("user.home")}/devspace/noa/cmake-build-release/kmath")
} }
} }

View File

@ -317,4 +317,16 @@ class JNoa {
public static native void zeroGradAdamOptim(long adamOptHandle); public static native void zeroGradAdamOptim(long adamOptHandle);
public static native void swapTensors(long lhsHandle, long rhsHandle); public static native void swapTensors(long lhsHandle, long rhsHandle);
public static native long loadTensorDouble(String path, int device);
public static native long loadTensorFloat(String path, int device);
public static native long loadTensorLong(String path, int device);
public static native long loadTensorInt(String path, int device);
public static native void saveTensor(long tensorHandle, String path);
public static native void saveJitModule(long jitModuleHandle, String path);
} }

View File

@ -141,6 +141,8 @@ protected constructor(protected val scope: NoaScope) :
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
public abstract fun loadTensor(path: String, device: Device = Device.CPU): TensorType
public fun NoaJitModule.forward(features: Tensor<T>): TensorType = public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle)) wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))
@ -275,6 +277,7 @@ protected constructor(scope: NoaScope) :
JNoa.qrTensor(tensor.tensorHandle, Q, R) JNoa.qrTensor(tensor.tensorHandle, Q, R)
return Pair(wrap(Q), wrap(R)) return Pair(wrap(Q), wrap(R))
} }
/** /**
* this implementation satisfies `tensor = P dot L dot U` * this implementation satisfies `tensor = P dot L dot U`
*/ */
@ -399,6 +402,9 @@ protected constructor(scope: NoaScope) :
override fun loadJitModule(path: String, device: Device): NoaJitModule = override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt())) NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt()))
override fun loadTensor(path: String, device: Device): NoaDoubleTensor =
wrap(JNoa.loadTensorDouble(path, device.toInt()))
} }
public sealed class NoaFloatAlgebra public sealed class NoaFloatAlgebra
@ -478,6 +484,8 @@ protected constructor(scope: NoaScope) :
override fun loadJitModule(path: String, device: Device): NoaJitModule = override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt())) NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt()))
override fun loadTensor(path: String, device: Device): NoaFloatTensor =
wrap(JNoa.loadTensorFloat(path, device.toInt()))
} }
public sealed class NoaLongAlgebra public sealed class NoaLongAlgebra
@ -542,6 +550,9 @@ protected constructor(scope: NoaScope) :
override fun loadJitModule(path: String, device: Device): NoaJitModule = override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt())) NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt()))
override fun loadTensor(path: String, device: Device): NoaLongTensor =
wrap(JNoa.loadTensorLong(path, device.toInt()))
} }
public sealed class NoaIntAlgebra public sealed class NoaIntAlgebra
@ -606,4 +617,7 @@ protected constructor(scope: NoaScope) :
override fun loadJitModule(path: String, device: Device): NoaJitModule = override fun loadJitModule(path: String, device: Device): NoaJitModule =
NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt())) NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt()))
override fun loadTensor(path: String, device: Device): NoaIntTensor =
wrap(JNoa.loadTensorInt(path, device.toInt()))
} }

View File

@ -14,4 +14,6 @@ public class NoaJitModule
internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle) internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle)
: NoaResource(scope){ : NoaResource(scope){
override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle) override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle)
public fun save(path: String): Unit = JNoa.saveJitModule(jitModuleHandle, path)
} }

View File

@ -33,6 +33,8 @@ protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle)
public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle)) public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle))
public fun save(path: String): Unit = JNoa.saveTensor(tensorHandle, path)
override fun toString(): String = JNoa.tensorToString(tensorHandle) override fun toString(): String = JNoa.tensorToString(tensorHandle)
@PerformancePitfall @PerformancePitfall

View File

@ -1199,6 +1199,54 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors
(JNIEnv *, jclass, jlong, jlong); (JNIEnv *, jclass, jlong, jlong);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: loadTensorDouble
* Signature: (Ljava/lang/String;I)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorDouble
(JNIEnv *, jclass, jstring, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: loadTensorFloat
* Signature: (Ljava/lang/String;I)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorFloat
(JNIEnv *, jclass, jstring, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: loadTensorLong
* Signature: (Ljava/lang/String;I)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorLong
(JNIEnv *, jclass, jstring, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: loadTensorInt
* Signature: (Ljava/lang/String;I)J
*/
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorInt
(JNIEnv *, jclass, jstring, jint);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: saveTensor
* Signature: (JLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveTensor
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: space_kscience_kmath_noa_JNoa
* Method: saveJitModule
* Signature: (JLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveJitModule
(JNIEnv *, jclass, jlong, jstring);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.noa package space.kscience.kmath.noa
import java.io.File
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -37,7 +38,18 @@ internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
assertEquals(tensor[intArrayOf(0)], 10) assertEquals(tensor[intArrayOf(0)], 10)
} }
internal fun NoaFloat.testingTensorSerialisation(tensorPath: String, device: Device = Device.CPU){
val tensor = copyFromArray(floatArrayOf(45.5f, 98.6f), intArrayOf(2), device)
tensor.save(tensorPath)
val loadedTensor = loadTensor(tensorPath, device)
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
}
class TestTensor { class TestTensor {
private val resources = File("").resolve("src/test/resources")
private val tensorPath = resources.resolve("tensor.pt").absolutePath
@Test @Test
fun testCopying() = NoaFloat { fun testCopying() = NoaFloat {
withCuda { device -> withCuda { device ->
@ -76,4 +88,11 @@ class TestTensor {
} }
}!! }!!
@Test
fun tensorSerialisation() = NoaFloat {
withCuda { device ->
testingTensorSerialisation(tensorPath, device)
}
}!!
} }

Binary file not shown.