Copyless data transfer

This commit is contained in:
rgrit91 2021-01-12 19:49:16 +00:00
parent 8967691b7d
commit 524b1d80d1
5 changed files with 87 additions and 24 deletions

View File

@ -18,10 +18,15 @@ extern "C"
void set_seed(int seed); void set_seed(int seed);
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device); double *get_data_double(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device); float *get_data_float(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device); long *get_data_long(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device); int *get_data_int(TorchTensorHandle tensor_handle);
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle); TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device); TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);

View File

@ -61,9 +61,9 @@ namespace ctorch
} }
template <typename Dtype> template <typename Dtype>
inline torch::Tensor copy_from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device) inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy)
{ {
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true); return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
} }
template <typename NumType> template <typename NumType>

View File

@ -25,6 +25,23 @@ void set_seed(int seed)
torch::manual_seed(seed); torch::manual_seed(seed);
} }
double *get_data_double(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<double>();
}
float *get_data_float(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<float>();
}
long *get_data_long(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<long>();
}
int *get_data_int(TorchTensorHandle tensor_handle)
{
return ctorch::cast(tensor_handle).data_ptr<int>();
}
int get_dim(TorchTensorHandle tensor_handle) int get_dim(TorchTensorHandle tensor_handle)
{ {
return ctorch::cast(tensor_handle).dim(); return ctorch::cast(tensor_handle).dim();
@ -46,21 +63,21 @@ int get_device(TorchTensorHandle tensor_handle)
return ctorch::device_to_int(ctorch::cast(tensor_handle)); return ctorch::device_to_int(ctorch::cast(tensor_handle));
} }
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device) TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
{ {
return new torch::Tensor(ctorch::copy_from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
} }
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device) TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy)
{ {
return new torch::Tensor(ctorch::copy_from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
} }
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device) TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy)
{ {
return new torch::Tensor(ctorch::copy_from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
} }
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device) TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy)
{ {
return new torch::Tensor(ctorch::copy_from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device))); return new torch::Tensor(ctorch::from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
} }
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle) TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
{ {
@ -245,7 +262,7 @@ TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
} }
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j) TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
{ {
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i,j)); return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
} }
bool requires_grad(TorchTensorHandle tensor_handle) bool requires_grad(TorchTensorHandle tensor_handle)

View File

@ -3,17 +3,21 @@ package kscience.kmath.torch
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.ctorch.* import kscience.kmath.ctorch.*
public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor( public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType> constructor(
internal val scope: DeferScope internal val scope: DeferScope
) { ) {
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T> internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
public abstract fun copyFromArray( public abstract fun copyFromArray(
array: PrimitiveArrayType, array: PrimitiveArrayType,
shape: IntArray, shape: IntArray,
device: TorchDevice = TorchDevice.TorchCPU device: TorchDevice = TorchDevice.TorchCPU
): TorchTensor<T> ): TorchTensor<T>
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T>
public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T> public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T> public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
@ -56,13 +60,13 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!) wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
} }
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) : public sealed class TorchTensorFieldAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) :
TorchTensorAlgebra<T, PrimitiveArrayType>(scope) { TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) {
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T> public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T> public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
} }
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleArray>(scope) { public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) {
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal = override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
TorchTensorReal(scope = scope, tensorHandle = tensorHandle) TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
@ -76,14 +80,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
): TorchTensorReal = ): TorchTensorReal =
TorchTensorReal( TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = copy_from_blob_double( tensorHandle = from_blob_double(
array.toCValues(), array.toCValues(),
shape.toCValues(), shape.toCValues(),
shape.size, shape.size,
device.toInt() device.toInt(),
true
)!! )!!
) )
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
TorchTensorReal(
scope = scope,
tensorHandle = from_blob_double(
arrayBlob,
shape.toCValues(),
shape.size,
TorchDevice.TorchCPU.toInt(),
false
)!!
)
override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> {
require(this.device is TorchDevice.TorchCPU){
"This tensor is not on available on CPU"
}
return get_data_double(this.tensorHandle)!!
}
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal( override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
scope = scope, scope = scope,
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!! tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
@ -116,5 +140,5 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
} }
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R = public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
memScoped { TorchTensorRealAlgebra(this).block() } memScoped { TorchTensorRealAlgebra(this).block() }

View File

@ -1,5 +1,6 @@
package kscience.kmath.torch package kscience.kmath.torch
import kotlinx.cinterop.*
import kotlin.test.* import kotlin.test.*
internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit { internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit {
@ -21,7 +22,23 @@ class TestTorchTensor {
fun testCopyFromArray() = testingCopyFromArray() fun testCopyFromArray() = testingCopyFromArray()
@Test @Test
fun testRequiresGrad() = TorchTensorRealAlgebra { fun testCopyLessDataTransferOnCPU() = memScoped {
val data = allocArray<DoubleVar>(1)
data[0] = 1.0
TorchTensorRealAlgebra {
val tensor = fromBlob(data, intArrayOf(1))
assertEquals(tensor[intArrayOf(0)], 1.0)
data[0] = 2.0
assertEquals(tensor[intArrayOf(0)], 2.0)
val tensorData = tensor.getData()
tensorData[0] = 3.0
println(assertEquals(tensor[intArrayOf(0)], 3.0))
}
assertEquals(data[0], 3.0)
}
@Test
fun testRequiresGrad() = TorchTensorRealAlgebra {
val tensor = randNormal(intArrayOf(3)) val tensor = randNormal(intArrayOf(3))
assertTrue(!tensor.requiresGrad) assertTrue(!tensor.requiresGrad)
tensor.requiresGrad = true tensor.requiresGrad = true