From 36bc1272606cc7ec1f0733b915a33dcc122c488c Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Fri, 30 Jul 2021 22:33:43 +0100 Subject: [PATCH] tensors and modules serialisation --- kmath-noa/README.md | 13 +++++ kmath-noa/build.gradle.kts | 2 + .../java/space/kscience/kmath/noa/JNoa.java | 12 +++++ .../space/kscience/kmath/noa/algebras.kt | 14 +++++ .../kmath/noa/{NoaJitModule.kt => jit.kt} | 2 + .../space/kscience/kmath/noa/tensors.kt | 2 + .../resources/space_kscience_kmath_noa_JNoa.h | 48 ++++++++++++++++++ .../space/kscience/kmath/noa/TestTensor.kt | 19 +++++++ kmath-noa/src/test/resources/tensor.pt | Bin 0 -> 1472 bytes 9 files changed, 112 insertions(+) rename kmath-noa/src/main/kotlin/space/kscience/kmath/noa/{NoaJitModule.kt => jit.kt} (86%) create mode 100644 kmath-noa/src/test/resources/tensor.pt diff --git a/kmath-noa/README.md b/kmath-noa/README.md index fffa26524..a0d1ecb9f 100644 --- a/kmath-noa/README.md +++ b/kmath-noa/README.md @@ -65,8 +65,18 @@ NoaFloat { // Reconstruct tensor val tensorReg = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)) + + // Serialise tensor for later + tensorReg.save("tensorReg.pt") } ``` +The saved tensor can be loaded in `C++` or in `python`: +```python +import torch +tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0] +``` + + ### Automatic Differentiation The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html) engine is exposed: @@ -200,6 +210,9 @@ NoaFloat { // Compute the loss on validation dataset lossModule.forwardAssign(yPred, loss) println("Validation loss: $loss") + + // The model can be serialised in its current state + netModule.save("trained_net.pt") } ``` diff --git a/kmath-noa/build.gradle.kts b/kmath-noa/build.gradle.kts index 85933eb4b..24b857b3d 100644 --- a/kmath-noa/build.gradle.kts +++ b/kmath-noa/build.gradle.kts @@ -161,6 +161,8 @@ tasks["compileJava"].dependsOn(buildCpp) tasks { withType{ systemProperty("java.library.path", "$cppBuildDir/kmath") + //systemProperty("java.library.path", + // "${System.getProperty("user.home")}/devspace/noa/cmake-build-release/kmath") } } diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index 24e3c2b05..8255725e7 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -317,4 +317,16 @@ class JNoa { public static native void zeroGradAdamOptim(long adamOptHandle); public static native void swapTensors(long lhsHandle, long rhsHandle); + + public static native long loadTensorDouble(String path, int device); + + public static native long loadTensorFloat(String path, int device); + + public static native long loadTensorLong(String path, int device); + + public static native long loadTensorInt(String path, int device); + + public static native void saveTensor(long tensorHandle, String path); + + public static native void saveJitModule(long jitModuleHandle, String path); } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 2d880656b..01d87715e 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -141,6 +141,8 @@ protected constructor(protected val scope: NoaScope) : public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule + public abstract fun loadTensor(path: String, device: Device = Device.CPU): TensorType + public fun NoaJitModule.forward(features: Tensor): TensorType = wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle)) @@ -275,6 +277,7 @@ protected constructor(scope: NoaScope) : JNoa.qrTensor(tensor.tensorHandle, Q, R) return Pair(wrap(Q), wrap(R)) } + /** * this implementation satisfies `tensor = P dot L dot U` */ @@ -399,6 +402,9 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt())) + + override fun loadTensor(path: String, device: Device): NoaDoubleTensor = + wrap(JNoa.loadTensorDouble(path, device.toInt())) } public sealed class NoaFloatAlgebra @@ -478,6 +484,8 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt())) + override fun loadTensor(path: String, device: Device): NoaFloatTensor = + wrap(JNoa.loadTensorFloat(path, device.toInt())) } public sealed class NoaLongAlgebra @@ -542,6 +550,9 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt())) + override fun loadTensor(path: String, device: Device): NoaLongTensor = + wrap(JNoa.loadTensorLong(path, device.toInt())) + } public sealed class NoaIntAlgebra @@ -606,4 +617,7 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt())) + override fun loadTensor(path: String, device: Device): NoaIntTensor = + wrap(JNoa.loadTensorInt(path, device.toInt())) + } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/jit.kt similarity index 86% rename from kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt rename to kmath-noa/src/main/kotlin/space/kscience/kmath/noa/jit.kt index 402cdf012..7f585ebac 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/jit.kt @@ -14,4 +14,6 @@ public class NoaJitModule internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle) : NoaResource(scope){ override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle) + + public fun save(path: String): Unit = JNoa.saveJitModule(jitModuleHandle, path) } \ No newline at end of file diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt index 3c9508db2..3f0bb85a0 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt @@ -33,6 +33,8 @@ protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle)) + public fun save(path: String): Unit = JNoa.saveTensor(tensorHandle, path) + override fun toString(): String = JNoa.tensorToString(tensorHandle) @PerformancePitfall diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index f058e859e..72828f6b9 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -1199,6 +1199,54 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors (JNIEnv *, jclass, jlong, jlong); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorDouble + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorDouble + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorFloat + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorFloat + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorLong + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorLong + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorInt + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorInt + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: saveTensor + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveTensor + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: saveJitModule + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveJitModule + (JNIEnv *, jclass, jlong, jstring); + #ifdef __cplusplus } #endif diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt index 41da55ca2..a5a0387a0 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.noa +import java.io.File import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -37,7 +38,18 @@ internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) { assertEquals(tensor[intArrayOf(0)], 10) } +internal fun NoaFloat.testingTensorSerialisation(tensorPath: String, device: Device = Device.CPU){ + val tensor = copyFromArray(floatArrayOf(45.5f, 98.6f), intArrayOf(2), device) + tensor.save(tensorPath) + val loadedTensor = loadTensor(tensorPath, device) + assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray()) +} + class TestTensor { + + private val resources = File("").resolve("src/test/resources") + private val tensorPath = resources.resolve("tensor.pt").absolutePath + @Test fun testCopying() = NoaFloat { withCuda { device -> @@ -76,4 +88,11 @@ class TestTensor { } }!! + @Test + fun tensorSerialisation() = NoaFloat { + withCuda { device -> + testingTensorSerialisation(tensorPath, device) + } + }!! + } \ No newline at end of file diff --git a/kmath-noa/src/test/resources/tensor.pt b/kmath-noa/src/test/resources/tensor.pt new file mode 100644 index 0000000000000000000000000000000000000000..e51bf9fed9bba5beb9018caaa482ca348af06a56 GIT binary patch literal 1472 zcmb7E-HOvd6h29}e}UaaSYa<@2?ZgAwe4% zFGLid!3%G^_QrcJd;om_AHipECh1R-HVYnlI_;TrzBxbNjE)Lx0F??X$vRvFYP*5w zH!Xu2&6cKKfG1f2xTdz-KUE!-D?e`kdCP&g>bba^a3hJ^7{72d_4+b$_N3MSC!&cl z^?Y-Hv2@?FLdTW})X!ckU4f<1B7zz1HgHIXP9S05?uA3g!f|Q%bYCK(qIAp5Ibfbf z-zJ4PMA973pbulCMAimEx5T{6&@v8Mq=-cNfRqp@?1zE``oT0k$FWU1^xQx?^ex-B zt-C{$N~ByrqDQ0>_DHqC8#Q`!`QA4M#WlZC+fbb~%JeL|Ij^%ZeyOR&xMQYc1c4kE zuEb0IK$2yQ$A)i=Y-;-f#_}!sX{}Y;m5-yXUU+gcx5p;Kbvz%@U`Esz%}x&bv@8pPDRbZ$};u^x}KS$)hd5^&8wWdrWaPB~=*U`)}}o9E)M zX8pq2j>hoobN`XZPUfm3fiLqroxGLnd?O{P8=ehK@^mv5##^2U<(?}%hdbJb^VpZm$EJ&5b8zttnmBRa*#7i(Bi;(Q~Mhc-*m}fMdu8vBj-Koz` q;mX{bzav~`>nHwJG5O69xUe|>Ai=XD{|4aOIx0r)B^Jf~W8Z%s7Dhq< literal 0 HcmV?d00001