diff --git a/kmath-noa/README.md b/kmath-noa/README.md index fffa26524..a0d1ecb9f 100644 --- a/kmath-noa/README.md +++ b/kmath-noa/README.md @@ -65,8 +65,18 @@ NoaFloat { // Reconstruct tensor val tensorReg = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose(-2, -1)) + + // Serialise tensor for later + tensorReg.save("tensorReg.pt") } ``` +The saved tensor can be loaded in `C++` or in `python`: +```python +import torch +tensor_reg = list(torch.jit.load('tensorReg.pt').parameters())[0] +``` + + ### Automatic Differentiation The [AutoGrad](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html) engine is exposed: @@ -200,6 +210,9 @@ NoaFloat { // Compute the loss on validation dataset lossModule.forwardAssign(yPred, loss) println("Validation loss: $loss") + + // The model can be serialised in its current state + netModule.save("trained_net.pt") } ``` diff --git a/kmath-noa/build.gradle.kts b/kmath-noa/build.gradle.kts index 85933eb4b..24b857b3d 100644 --- a/kmath-noa/build.gradle.kts +++ b/kmath-noa/build.gradle.kts @@ -161,6 +161,8 @@ tasks["compileJava"].dependsOn(buildCpp) tasks { withType{ systemProperty("java.library.path", "$cppBuildDir/kmath") + //systemProperty("java.library.path", + // "${System.getProperty("user.home")}/devspace/noa/cmake-build-release/kmath") } } diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index 24e3c2b05..8255725e7 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -317,4 +317,16 @@ class JNoa { public static native void zeroGradAdamOptim(long adamOptHandle); public static native void swapTensors(long lhsHandle, long rhsHandle); + + public static native long loadTensorDouble(String path, int device); + + public static native long loadTensorFloat(String path, int device); + + public static native long loadTensorLong(String path, int device); + + public static native long loadTensorInt(String path, int device); + + public static native void saveTensor(long tensorHandle, String path); + + public static native void saveJitModule(long jitModuleHandle, String path); } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 2d880656b..01d87715e 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -141,6 +141,8 @@ protected constructor(protected val scope: NoaScope) : public abstract fun loadJitModule(path: String, device: Device = Device.CPU): NoaJitModule + public abstract fun loadTensor(path: String, device: Device = Device.CPU): TensorType + public fun NoaJitModule.forward(features: Tensor): TensorType = wrap(JNoa.forwardPass(jitModuleHandle, features.tensor.tensorHandle)) @@ -275,6 +277,7 @@ protected constructor(scope: NoaScope) : JNoa.qrTensor(tensor.tensorHandle, Q, R) return Pair(wrap(Q), wrap(R)) } + /** * this implementation satisfies `tensor = P dot L dot U` */ @@ -399,6 +402,9 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleDouble(path, device.toInt())) + + override fun loadTensor(path: String, device: Device): NoaDoubleTensor = + wrap(JNoa.loadTensorDouble(path, device.toInt())) } public sealed class NoaFloatAlgebra @@ -478,6 +484,8 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleFloat(path, device.toInt())) + override fun loadTensor(path: String, device: Device): NoaFloatTensor = + wrap(JNoa.loadTensorFloat(path, device.toInt())) } public sealed class NoaLongAlgebra @@ -542,6 +550,9 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleLong(path, device.toInt())) + override fun loadTensor(path: String, device: Device): NoaLongTensor = + wrap(JNoa.loadTensorLong(path, device.toInt())) + } public sealed class NoaIntAlgebra @@ -606,4 +617,7 @@ protected constructor(scope: NoaScope) : override fun loadJitModule(path: String, device: Device): NoaJitModule = NoaJitModule(scope, JNoa.loadJitModuleInt(path, device.toInt())) + override fun loadTensor(path: String, device: Device): NoaIntTensor = + wrap(JNoa.loadTensorInt(path, device.toInt())) + } diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/jit.kt similarity index 86% rename from kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt rename to kmath-noa/src/main/kotlin/space/kscience/kmath/noa/jit.kt index 402cdf012..7f585ebac 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/NoaJitModule.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/jit.kt @@ -14,4 +14,6 @@ public class NoaJitModule internal constructor(scope: NoaScope, internal val jitModuleHandle: JitModuleHandle) : NoaResource(scope){ override fun dispose(): Unit = JNoa.disposeJitModule(jitModuleHandle) + + public fun save(path: String): Unit = JNoa.saveJitModule(jitModuleHandle, path) } \ No newline at end of file diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt index 3c9508db2..3f0bb85a0 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt @@ -33,6 +33,8 @@ protected constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle)) + public fun save(path: String): Unit = JNoa.saveTensor(tensorHandle, path) + override fun toString(): String = JNoa.tensorToString(tensorHandle) @PerformancePitfall diff --git a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h index f058e859e..72828f6b9 100644 --- a/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h +++ b/kmath-noa/src/main/resources/space_kscience_kmath_noa_JNoa.h @@ -1199,6 +1199,54 @@ JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_zeroGradAdamOptim JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_swapTensors (JNIEnv *, jclass, jlong, jlong); +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorDouble + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorDouble + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorFloat + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorFloat + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorLong + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorLong + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: loadTensorInt + * Signature: (Ljava/lang/String;I)J + */ +JNIEXPORT jlong JNICALL Java_space_kscience_kmath_noa_JNoa_loadTensorInt + (JNIEnv *, jclass, jstring, jint); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: saveTensor + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveTensor + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: space_kscience_kmath_noa_JNoa + * Method: saveJitModule + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_space_kscience_kmath_noa_JNoa_saveJitModule + (JNIEnv *, jclass, jlong, jstring); + #ifdef __cplusplus } #endif diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt index 41da55ca2..a5a0387a0 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestTensor.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.noa +import java.io.File import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -37,7 +38,18 @@ internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) { assertEquals(tensor[intArrayOf(0)], 10) } +internal fun NoaFloat.testingTensorSerialisation(tensorPath: String, device: Device = Device.CPU){ + val tensor = copyFromArray(floatArrayOf(45.5f, 98.6f), intArrayOf(2), device) + tensor.save(tensorPath) + val loadedTensor = loadTensor(tensorPath, device) + assertTrue(tensor.copyToArray() contentEquals loadedTensor.copyToArray()) +} + class TestTensor { + + private val resources = File("").resolve("src/test/resources") + private val tensorPath = resources.resolve("tensor.pt").absolutePath + @Test fun testCopying() = NoaFloat { withCuda { device -> @@ -76,4 +88,11 @@ class TestTensor { } }!! + @Test + fun tensorSerialisation() = NoaFloat { + withCuda { device -> + testingTensorSerialisation(tensorPath, device) + } + }!! + } \ No newline at end of file diff --git a/kmath-noa/src/test/resources/tensor.pt b/kmath-noa/src/test/resources/tensor.pt new file mode 100644 index 000000000..e51bf9fed Binary files /dev/null and b/kmath-noa/src/test/resources/tensor.pt differ