Copyless data transfer

This commit is contained in:
rgrit91 2021-01-12 19:49:16 +00:00
parent 8967691b7d
commit 524b1d80d1
5 changed files with 87 additions and 24 deletions

View File

@ -18,10 +18,15 @@ extern "C"
void set_seed(int seed);
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device);
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device);
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device);
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device);
double *get_data_double(TorchTensorHandle tensor_handle);
float *get_data_float(TorchTensorHandle tensor_handle);
long *get_data_long(TorchTensorHandle tensor_handle);
int *get_data_int(TorchTensorHandle tensor_handle);
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);

View File

@ -61,9 +61,9 @@ namespace ctorch
}
template <typename Dtype>
inline torch::Tensor copy_from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device)
inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy)
{
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true);
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
}
template <typename NumType>

View File

@ -25,6 +25,23 @@ void set_seed(int seed)
torch::manual_seed(seed);
}
double *get_data_double(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<double>();
}
float *get_data_float(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<float>();
}
long *get_data_long(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<long>();
}
int *get_data_int(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<int>();
}
int get_dim(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).dim();
@ -46,21 +63,21 @@ int get_device(TorchTensorHandle tensor_handle)
return ctorch::device_to_int(ctorch::cast(tensor_handle));
}
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device)
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
{
return new torch::Tensor(ctorch::copy_from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
}
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device)
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy)
{
return new torch::Tensor(ctorch::copy_from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
return new torch::Tensor(ctorch::from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
}
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device)
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy)
{
return new torch::Tensor(ctorch::copy_from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
return new torch::Tensor(ctorch::from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
}
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device)
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy)
{
return new torch::Tensor(ctorch::copy_from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
return new torch::Tensor(ctorch::from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
}
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
{
@ -245,7 +262,7 @@ TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
}
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
{
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i,j));
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
}
bool requires_grad(TorchTensorHandle tensor_handle)

View File

@ -3,17 +3,21 @@ package kscience.kmath.torch
import kotlinx.cinterop.*
import kscience.kmath.ctorch.*
public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType> constructor(
internal val scope: DeferScope
) {
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
public abstract fun copyFromArray(
array: PrimitiveArrayType,
shape: IntArray,
device: TorchDevice = TorchDevice.TorchCPU
): TorchTensor<T>
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T>
public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
@ -56,13 +60,13 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
}
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
TorchTensorAlgebra<T, PrimitiveArrayType>(scope) {
public sealed class TorchTensorFieldAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) :
TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) {
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
}
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleArray>(scope) {
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
@ -76,14 +80,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
): TorchTensorReal =
TorchTensorReal(
scope = scope,
tensorHandle = copy_from_blob_double(
tensorHandle = from_blob_double(
array.toCValues(),
shape.toCValues(),
shape.size,
device.toInt()
device.toInt(),
true
)!!
)
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
TorchTensorReal(
scope = scope,
tensorHandle = from_blob_double(
arrayBlob,
shape.toCValues(),
shape.size,
TorchDevice.TorchCPU.toInt(),
false
)!!
)
override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> {
require(this.device is TorchDevice.TorchCPU){
"This tensor is not on available on CPU"
}
return get_data_double(this.tensorHandle)!!
}
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
scope = scope,
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
@ -116,5 +140,5 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
}
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
memScoped { TorchTensorRealAlgebra(this).block() }

View File

@ -1,5 +1,6 @@
package kscience.kmath.torch
import kotlinx.cinterop.*
import kotlin.test.*
internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit {
@ -20,6 +21,22 @@ class TestTorchTensor {
@Test
fun testCopyFromArray() = testingCopyFromArray()
@Test
fun testCopyLessDataTransferOnCPU() = memScoped {
val data = allocArray<DoubleVar>(1)
data[0] = 1.0
TorchTensorRealAlgebra {
val tensor = fromBlob(data, intArrayOf(1))
assertEquals(tensor[intArrayOf(0)], 1.0)
data[0] = 2.0
assertEquals(tensor[intArrayOf(0)], 2.0)
val tensorData = tensor.getData()
tensorData[0] = 3.0
println(assertEquals(tensor[intArrayOf(0)], 3.0))
}
assertEquals(data[0], 3.0)
}
@Test
fun testRequiresGrad() = TorchTensorRealAlgebra {
val tensor = randNormal(intArrayOf(3))