tensors and modules serialisation
This commit is contained in:
parent
686baa6517
commit
36bc127260
@ -65,8 +65,18 @@ NoaFloat {
|
||||
// Reconstruct tensor
|
||||
val tensorReg =
|
||||
tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
|
||||
|
||||
// Serialise tensor for later
|
||||
tensorReg.save("tensorReg.pt")
|
||||
}
|
||||
```
|
||||
The saved tensor can be loaded in `C++` or in `python`:
|
||||
```python
|
||||
import torch
|
||||
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
|
||||
```
|
||||
|
||||
|
||||
### Automatic Differentiation
|
||||
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
|
||||
engine is exposed:
|
||||
@ -200,6 +210,9 @@ NoaFloat {
|
||||
// Compute the loss on validation dataset
|
||||
lossModule.forwardAssign(yPred, loss)
|
||||
println("Validation loss: $loss")
|
||||
|
||||
// The model can be serialised in its current state
|
||||
netModule.save("trained_net.pt")
|
||||
}
|
||||
|
||||
```
|
||||
|
@ -161,6 +161,8 @@ tasks["compileJava"].dependsOn(buildCpp)
|
||||
tasks {
|
||||
withType<Test>{
|
||||
systemProperty("java.library.path", "$cppBuildDir/kmath")
|
||||
//systemProperty("java.library.path",
|
||||
// "${System.getProperty("user.home")}/devspace/noa/cmake-build-release/kmath")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -317,4 +317,16 @@ class JNoa {
|
||||
public static native void zeroGradAdamOptim(long adamOptHandle);
|
||||
|
||||
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||
|
||||
public static native long loadTensorDouble(String path, int device);
|
||||
|
||||
public static native long loadTensorFloat(String path, int device);
|
||||
|
||||
public static native long loadTensorLong(String path, int device);
|
||||
|
||||
public static native long loadTensorInt(String path, int device);
|
||||
|
||||
public static native void saveTensor(long tensorHandle, String path);
|
||||
|
||||
public static native void saveJitModule(long jitModuleHandle, String path);
|
||||
}
|
||||
|
@ -141,6 +141,8 @@ protected constructor(protected val scope: NoaScope) :
|
||||
|
||||
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
|
||||
|
||||
public abstract fun loadTensor(path: String, device: Device = Device.CPU): TensorType
|
||||
|
||||
public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
|
||||
wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))
|
||||
|
||||
@ -275,6 +277,7 @@ protected constructor(scope: NoaScope) :
|
||||
JNoa.qrTensor(tensor.tensorHandle, Q, R)
|
||||
return Pair(wrap(Q), wrap(R))
|
||||
}
|
||||
|
||||
/**
|
||||
* this implementation satisfies `tensor = P dot L dot U`
|
||||
*/
|
||||
@ -399,6 +402,9 @@ protected constructor(scope: NoaScope) :
|
||||
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt()))
|
||||
|
||||
override fun loadTensor(path: String, device: Device): NoaDoubleTensor =
|
||||
wrap(JNoa.loadTensorDouble(path, device.toInt()))
|
||||
}
|
||||
|
||||
public sealed class NoaFloatAlgebra
|
||||
@ -478,6 +484,8 @@ protected constructor(scope: NoaScope) :
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt()))
|
||||
|
||||
override fun loadTensor(path: String, device: Device): NoaFloatTensor =
|
||||
wrap(JNoa.loadTensorFloat(path, device.toInt()))
|
||||
}
|
||||
|
||||
public sealed class NoaLongAlgebra
|
||||
@ -542,6 +550,9 @@ protected constructor(scope: NoaScope) :
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt()))
|
||||
|
||||
override fun loadTensor(path: String, device: Device): NoaLongTensor =
|
||||
wrap(JNoa.loadTensorLong(path, device.toInt()))
|
||||
|
||||
}
|
||||
|
||||
public sealed class NoaIntAlgebra
|
||||
@ -606,4 +617,7 @@ protected constructor(scope: NoaScope) :
|
||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||
NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt()))
|
||||
|
||||
override fun loadTensor(path: String, device: Device): NoaIntTensor =
|
||||
wrap(JNoa.loadTensorInt(path, device.toInt()))
|
||||
|
||||
}
|
||||
|
@ -14,4 +14,6 @@ public class NoaJitModule
|
||||
internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle)
|
||||
: NoaResource(scope){
|
||||
override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle)
|
||||
|
||||
public fun save(path: String): Unit = JNoa.saveJitModule(jitModuleHandle, path)
|
||||
}
|
@ -33,6 +33,8 @@ protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle)
|
||||
|
||||
public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle))
|
||||
|
||||
public fun save(path: String): Unit = JNoa.saveTensor(tensorHandle, path)
|
||||
|
||||
override fun toString(): String = JNoa.tensorToString(tensorHandle)
|
||||
|
||||
@PerformancePitfall
|
||||
|
@ -1199,6 +1199,54 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors
|
||||
(JNIEnv *, jclass, jlong, jlong);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: loadTensorDouble
|
||||
* Signature: (Ljava/lang/String;I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorDouble
|
||||
(JNIEnv *, jclass, jstring, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: loadTensorFloat
|
||||
* Signature: (Ljava/lang/String;I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorFloat
|
||||
(JNIEnv *, jclass, jstring, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: loadTensorLong
|
||||
* Signature: (Ljava/lang/String;I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorLong
|
||||
(JNIEnv *, jclass, jstring, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: loadTensorInt
|
||||
* Signature: (Ljava/lang/String;I)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorInt
|
||||
(JNIEnv *, jclass, jstring, jint);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: saveTensor
|
||||
* Signature: (JLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveTensor
|
||||
(JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
/*
|
||||
* Class: space_kscience_kmath_noa_JNoa
|
||||
* Method: saveJitModule
|
||||
* Signature: (JLjava/lang/String;)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveJitModule
|
||||
(JNIEnv *, jclass, jlong, jstring);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
package space.kscience.kmath.noa
|
||||
|
||||
import java.io.File
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
@ -37,7 +38,18 @@ internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
|
||||
assertEquals(tensor[intArrayOf(0)], 10)
|
||||
}
|
||||
|
||||
internal fun NoaFloat.testingTensorSerialisation(tensorPath: String, device: Device = Device.CPU){
|
||||
val tensor = copyFromArray(floatArrayOf(45.5f, 98.6f), intArrayOf(2), device)
|
||||
tensor.save(tensorPath)
|
||||
val loadedTensor = loadTensor(tensorPath, device)
|
||||
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
|
||||
}
|
||||
|
||||
class TestTensor {
|
||||
|
||||
private val resources = File("").resolve("src/test/resources")
|
||||
private val tensorPath = resources.resolve("tensor.pt").absolutePath
|
||||
|
||||
@Test
|
||||
fun testCopying() = NoaFloat {
|
||||
withCuda { device ->
|
||||
@ -76,4 +88,11 @@ class TestTensor {
|
||||
}
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun tensorSerialisation() = NoaFloat {
|
||||
withCuda { device ->
|
||||
testingTensorSerialisation(tensorPath, device)
|
||||
}
|
||||
}!!
|
||||
|
||||
}
|
BIN
kmath-noa/src/test/resources/tensor.pt
Normal file
BIN
kmath-noa/src/test/resources/tensor.pt
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user