From 524b1d80d1d1cde6ed78c339a1407583fa017afd Mon Sep 17 00:00:00 2001
From: rgrit91
Date: Tue, 12 Jan 2021 19:49:16 +0000
Subject: [PATCH] Copyless data transfer

---
 kmath-torch/ctorch/include/ctorch.h                 | 13 ++++--
 kmath-torch/ctorch/include/utils.hh                 |  4 +-
 kmath-torch/ctorch/src/ctorch.cc                    | 35 +++++++++++-----
 .../TorchTensorAlgebra.kt                           | 40 +++++++++++++++----
 .../kscience/kmath/torch/TestTorchTensor.kt         | 19 ++++++++-
 5 files changed, 87 insertions(+), 24 deletions(-)

diff --git a/kmath-torch/ctorch/include/ctorch.h b/kmath-torch/ctorch/include/ctorch.h
index 66bf41446..9dcbd8288 100644
--- a/kmath-torch/ctorch/include/ctorch.h
+++ b/kmath-torch/ctorch/include/ctorch.h
@@ -18,10 +18,15 @@ extern "C"
 
     void set_seed(int seed);
 
-    TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device);
-    TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device);
-    TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device);
-    TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device);
+    double *get_data_double(TorchTensorHandle tensor_handle);
+    float *get_data_float(TorchTensorHandle tensor_handle);
+    long *get_data_long(TorchTensorHandle tensor_handle);
+    int *get_data_int(TorchTensorHandle tensor_handle);
+
+    TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
+    TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
+    TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
+    TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
 
     TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
     TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
diff --git a/kmath-torch/ctorch/include/utils.hh b/kmath-torch/ctorch/include/utils.hh
index 79de077da..b92244144 100644
--- a/kmath-torch/ctorch/include/utils.hh
+++ b/kmath-torch/ctorch/include/utils.hh
@@ -61,9 +61,9 @@ namespace ctorch
     }
 
     template <typename Dtype>
-    inline torch::Tensor copy_from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device)
+    inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy)
     {
-        return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true);
+        return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
     }
 
     template <typename Dtype>
diff --git a/kmath-torch/ctorch/src/ctorch.cc b/kmath-torch/ctorch/src/ctorch.cc
index 37b20811e..effaff6fe 100644
--- a/kmath-torch/ctorch/src/ctorch.cc
+++ b/kmath-torch/ctorch/src/ctorch.cc
@@ -25,6 +25,23 @@ void set_seed(int seed)
     torch::manual_seed(seed);
 }
 
+double *get_data_double(TorchTensorHandle tensor_handle)
+{
+    return ctorch::cast(tensor_handle).data_ptr<double>();
+}
+float *get_data_float(TorchTensorHandle tensor_handle)
+{
+    return ctorch::cast(tensor_handle).data_ptr<float>();
+}
+long *get_data_long(TorchTensorHandle tensor_handle)
+{
+    return ctorch::cast(tensor_handle).data_ptr<long>();
+}
+int *get_data_int(TorchTensorHandle tensor_handle)
+{
+    return ctorch::cast(tensor_handle).data_ptr<int>();
+}
+
 int get_dim(TorchTensorHandle tensor_handle)
 {
     return ctorch::cast(tensor_handle).dim();
@@ -46,21 +63,21 @@ int get_device(TorchTensorHandle tensor_handle)
     return ctorch::device_to_int(ctorch::cast(tensor_handle));
 }
 
-TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device)
+TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
 {
-    return new torch::Tensor(ctorch::copy_from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
+    return new torch::Tensor(ctorch::from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
 }
-TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device)
+TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy)
 {
-    return new torch::Tensor(ctorch::copy_from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
+    return new torch::Tensor(ctorch::from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
 }
-TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device)
+TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy)
 {
-    return new torch::Tensor(ctorch::copy_from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
+    return new torch::Tensor(ctorch::from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
 }
-TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device)
+TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy)
 {
-    return new torch::Tensor(ctorch::copy_from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
+    return new torch::Tensor(ctorch::from_blob(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
 }
 TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
 {
@@ -245,7 +262,7 @@ TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
 }
 TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
 {
-    return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i,j));
+    return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
 }
 
 bool requires_grad(TorchTensorHandle tensor_handle)
diff --git a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
index 80df73765..f3771b414 100644
--- a/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
+++ b/kmath-torch/src/nativeMain/kotlin/kscience.kmath.torch/TorchTensorAlgebra.kt
@@ -3,17 +3,21 @@ package kscience.kmath.torch
 import kotlinx.cinterop.*
 import kscience.kmath.ctorch.*
 
-public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
+public sealed class TorchTensorAlgebra<T, TVar : CPrimitiveVar, PrimitiveArrayType> constructor(
     internal val scope: DeferScope
 ) {
     internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
+
     public abstract fun copyFromArray(
         array: PrimitiveArrayType,
         shape: IntArray,
         device: TorchDevice = TorchDevice.TorchCPU
     ): TorchTensor<T>
-
     public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
+
+    public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T>
+    public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
+
     public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
 
     public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
@@ -56,13 +60,13 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
         wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
 }
 
-public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
-    TorchTensorAlgebra<T, PrimitiveArrayType>(scope) {
+public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) :
+    TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) {
     public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
     public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
 }
 
-public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleArray>(scope) {
+public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) {
     override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
         TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
 
@@ -76,14 +80,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
     ): TorchTensorReal =
         TorchTensorReal(
             scope = scope,
-            tensorHandle = copy_from_blob_double(
+            tensorHandle = from_blob_double(
                 array.toCValues(),
                 shape.toCValues(),
                 shape.size,
-                device.toInt()
+                device.toInt(),
+                true
             )!!
         )
 
+    override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
+        TorchTensorReal(
+            scope = scope,
+            tensorHandle = from_blob_double(
+                arrayBlob,
+                shape.toCValues(),
+                shape.size,
+                TorchDevice.TorchCPU.toInt(),
+                false
+            )!!
+        )
+
+    override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> {
+        require(this.device is TorchDevice.TorchCPU) {
+            "This tensor is not available on the CPU"
+        }
+        return get_data_double(this.tensorHandle)!!
+    }
+
     override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
         scope = scope,
         tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
@@ -116,5 +140,5 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
 }
 
-public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
+public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
     memScoped { TorchTensorRealAlgebra(this).block() }
\ No newline at end of file
diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
index 5a6b07473..1dc3630c6 100644
--- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
+++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensor.kt
@@ -1,5 +1,6 @@
 package kscience.kmath.torch
 
+import kotlinx.cinterop.*
 import kotlin.test.*
 
 internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit {
@@ -21,7 +22,23 @@ class TestTorchTensor {
     fun testCopyFromArray() = testingCopyFromArray()
 
     @Test
-    fun testRequiresGrad() = TorchTensorRealAlgebra {
+    fun testCopyLessDataTransferOnCPU() = memScoped {
+        val data = allocArray<DoubleVar>(1)
+        data[0] = 1.0
+        TorchTensorRealAlgebra {
+            val tensor = fromBlob(data, intArrayOf(1))
+            assertEquals(tensor[intArrayOf(0)], 1.0)
+            data[0] = 2.0
+            assertEquals(tensor[intArrayOf(0)], 2.0)
+            val tensorData = tensor.getData()
+            tensorData[0] = 3.0
+            assertEquals(tensor[intArrayOf(0)], 3.0)
+        }
+        assertEquals(data[0], 3.0)
+    }
+
+    @Test
+    fun testRequiresGrad() = TorchTensorRealAlgebra {
         val tensor = randNormal(intArrayOf(3))
         assertTrue(!tensor.requiresGrad)
         tensor.requiresGrad = true
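
Usage note (not part of the patch): the sketch below shows how the zero-copy path introduced here is meant to be used from Kotlin/Native. It is a minimal example assuming only the API added above (fromBlob, getData, and the TorchTensorRealAlgebra scope builder); the function name zeroCopyExample is hypothetical.

import kotlinx.cinterop.*

// Minimal sketch: share one native buffer between Kotlin and Torch
// without copying, using the fromBlob/getData API added in this patch.
fun zeroCopyExample() = memScoped {
    val n = 4
    val buffer = allocArray<DoubleVar>(n)  // native memory owned by this scope
    for (i in 0 until n) buffer[i] = i.toDouble()

    TorchTensorRealAlgebra {
        // Wraps the buffer in place; the C side passes copy = false,
        // so no allocation or copy happens here.
        val tensor = fromBlob(buffer, intArrayOf(n))

        // Writes through the raw pointer are visible to the tensor...
        buffer[0] = 42.0
        println(tensor[intArrayOf(0)])     // prints 42.0

        // ...and writes through the tensor's data pointer reach the buffer.
        val data = tensor.getData()
        data[1] = -1.0
        println(buffer[1])                 // prints -1.0
    }
}

Because torch::from_blob wraps the caller's memory without taking ownership, the buffer must outlive the tensor; both the new test and this sketch therefore keep all tensor work inside the memScoped block that allocated the buffer.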