forked from kscience/kmath
tensors and modules serialisation
This commit is contained in:
parent
686baa6517
commit
36bc127260
@ -65,8 +65,18 @@ NoaFloat {
|
|||||||
// Reconstruct tensor
|
// Reconstruct tensor
|
||||||
val tensorReg =
|
val tensorReg =
|
||||||
tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
|
tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1))
|
||||||
|
|
||||||
|
// Serialise tensor for later
|
||||||
|
tensorReg.save("tensorReg.pt")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
The saved tensor can be loaded in `C++` or in `python`:
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Automatic Differentiation
|
### Automatic Differentiation
|
||||||
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
|
The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html)
|
||||||
engine is exposed:
|
engine is exposed:
|
||||||
@ -200,6 +210,9 @@ NoaFloat {
|
|||||||
// Compute the loss on validation dataset
|
// Compute the loss on validation dataset
|
||||||
lossModule.forwardAssign(yPred, loss)
|
lossModule.forwardAssign(yPred, loss)
|
||||||
println("Validation loss: $loss")
|
println("Validation loss: $loss")
|
||||||
|
|
||||||
|
// The model can be serialised in its current state
|
||||||
|
netModule.save("trained_net.pt")
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -161,6 +161,8 @@ tasks["compileJava"].dependsOn(buildCpp)
|
|||||||
tasks {
|
tasks {
|
||||||
withType<Test>{
|
withType<Test>{
|
||||||
systemProperty("java.library.path", "$cppBuildDir/kmath")
|
systemProperty("java.library.path", "$cppBuildDir/kmath")
|
||||||
|
//systemProperty("java.library.path",
|
||||||
|
// "${System.getProperty("user.home")}/devspace/noa/cmake-build-release/kmath")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -317,4 +317,16 @@ class JNoa {
|
|||||||
public static native void zeroGradAdamOptim(long adamOptHandle);
|
public static native void zeroGradAdamOptim(long adamOptHandle);
|
||||||
|
|
||||||
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
public static native void swapTensors(long lhsHandle, long rhsHandle);
|
||||||
|
|
||||||
|
public static native long loadTensorDouble(String path, int device);
|
||||||
|
|
||||||
|
public static native long loadTensorFloat(String path, int device);
|
||||||
|
|
||||||
|
public static native long loadTensorLong(String path, int device);
|
||||||
|
|
||||||
|
public static native long loadTensorInt(String path, int device);
|
||||||
|
|
||||||
|
public static native void saveTensor(long tensorHandle, String path);
|
||||||
|
|
||||||
|
public static native void saveJitModule(long jitModuleHandle, String path);
|
||||||
}
|
}
|
||||||
|
@ -141,6 +141,8 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
|
public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule
|
||||||
|
|
||||||
|
public abstract fun loadTensor(path: String, device: Device = Device.CPU): TensorType
|
||||||
|
|
||||||
public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
|
public fun NoaJitModule.forward(features: Tensor<T>): TensorType =
|
||||||
wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))
|
wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle))
|
||||||
|
|
||||||
@ -275,6 +277,7 @@ protected constructor(scope: NoaScope) :
|
|||||||
JNoa.qrTensor(tensor.tensorHandle, Q, R)
|
JNoa.qrTensor(tensor.tensorHandle, Q, R)
|
||||||
return Pair(wrap(Q), wrap(R))
|
return Pair(wrap(Q), wrap(R))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this implementation satisfies `tensor = P dot L dot U`
|
* this implementation satisfies `tensor = P dot L dot U`
|
||||||
*/
|
*/
|
||||||
@ -399,6 +402,9 @@ protected constructor(scope: NoaScope) :
|
|||||||
|
|
||||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||||
NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt()))
|
NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt()))
|
||||||
|
|
||||||
|
override fun loadTensor(path: String, device: Device): NoaDoubleTensor =
|
||||||
|
wrap(JNoa.loadTensorDouble(path, device.toInt()))
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaFloatAlgebra
|
public sealed class NoaFloatAlgebra
|
||||||
@ -478,6 +484,8 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||||
NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt()))
|
NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt()))
|
||||||
|
|
||||||
|
override fun loadTensor(path: String, device: Device): NoaFloatTensor =
|
||||||
|
wrap(JNoa.loadTensorFloat(path, device.toInt()))
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaLongAlgebra
|
public sealed class NoaLongAlgebra
|
||||||
@ -542,6 +550,9 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||||
NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt()))
|
NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt()))
|
||||||
|
|
||||||
|
override fun loadTensor(path: String, device: Device): NoaLongTensor =
|
||||||
|
wrap(JNoa.loadTensorLong(path, device.toInt()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaIntAlgebra
|
public sealed class NoaIntAlgebra
|
||||||
@ -606,4 +617,7 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
override fun loadJitModule(path: String, device: Device): NoaJitModule =
|
||||||
NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt()))
|
NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt()))
|
||||||
|
|
||||||
|
override fun loadTensor(path: String, device: Device): NoaIntTensor =
|
||||||
|
wrap(JNoa.loadTensorInt(path, device.toInt()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -14,4 +14,6 @@ public class NoaJitModule
|
|||||||
internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle)
|
internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle)
|
||||||
: NoaResource(scope){
|
: NoaResource(scope){
|
||||||
override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle)
|
override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle)
|
||||||
|
|
||||||
|
public fun save(path: String): Unit = JNoa.saveJitModule(jitModuleHandle, path)
|
||||||
}
|
}
|
@ -33,6 +33,8 @@ protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle)
|
|||||||
|
|
||||||
public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle))
|
public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle))
|
||||||
|
|
||||||
|
public fun save(path: String): Unit = JNoa.saveTensor(tensorHandle, path)
|
||||||
|
|
||||||
override fun toString(): String = JNoa.tensorToString(tensorHandle)
|
override fun toString(): String = JNoa.tensorToString(tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
|
@ -1199,6 +1199,54 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim
|
|||||||
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors
|
||||||
(JNIEnv *, jclass, jlong, jlong);
|
(JNIEnv *, jclass, jlong, jlong);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: loadTensorDouble
|
||||||
|
* Signature: (Ljava/lang/String;I)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorDouble
|
||||||
|
(JNIEnv *, jclass, jstring, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: loadTensorFloat
|
||||||
|
* Signature: (Ljava/lang/String;I)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorFloat
|
||||||
|
(JNIEnv *, jclass, jstring, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: loadTensorLong
|
||||||
|
* Signature: (Ljava/lang/String;I)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorLong
|
||||||
|
(JNIEnv *, jclass, jstring, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: loadTensorInt
|
||||||
|
* Signature: (Ljava/lang/String;I)J
|
||||||
|
*/
|
||||||
|
JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorInt
|
||||||
|
(JNIEnv *, jclass, jstring, jint);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: saveTensor
|
||||||
|
* Signature: (JLjava/lang/String;)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveTensor
|
||||||
|
(JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: space_kscience_kmath_noa_JNoa
|
||||||
|
* Method: saveJitModule
|
||||||
|
* Signature: (JLjava/lang/String;)V
|
||||||
|
*/
|
||||||
|
JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveJitModule
|
||||||
|
(JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.noa
|
package space.kscience.kmath.noa
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
@ -37,7 +38,18 @@ internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
|
|||||||
assertEquals(tensor[intArrayOf(0)], 10)
|
assertEquals(tensor[intArrayOf(0)], 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun NoaFloat.testingTensorSerialisation(tensorPath: String, device: Device = Device.CPU){
|
||||||
|
val tensor = copyFromArray(floatArrayOf(45.5f, 98.6f), intArrayOf(2), device)
|
||||||
|
tensor.save(tensorPath)
|
||||||
|
val loadedTensor = loadTensor(tensorPath, device)
|
||||||
|
assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray())
|
||||||
|
}
|
||||||
|
|
||||||
class TestTensor {
|
class TestTensor {
|
||||||
|
|
||||||
|
private val resources = File("").resolve("src/test/resources")
|
||||||
|
private val tensorPath = resources.resolve("tensor.pt").absolutePath
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testCopying() = NoaFloat {
|
fun testCopying() = NoaFloat {
|
||||||
withCuda { device ->
|
withCuda { device ->
|
||||||
@ -76,4 +88,11 @@ class TestTensor {
|
|||||||
}
|
}
|
||||||
}!!
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun tensorSerialisation() = NoaFloat {
|
||||||
|
withCuda { device ->
|
||||||
|
testingTensorSerialisation(tensorPath, device)
|
||||||
|
}
|
||||||
|
}!!
|
||||||
|
|
||||||
}
|
}
|
BIN
kmath-noa/src/test/resources/tensor.pt
Normal file
BIN
kmath-noa/src/test/resources/tensor.pt
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user