Copyless data transfer
This commit is contained in:
parent
8967691b7d
commit
524b1d80d1
@ -18,10 +18,15 @@ extern "C"
|
|||||||
|
|
||||||
void set_seed(int seed);
|
void set_seed(int seed);
|
||||||
|
|
||||||
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device);
|
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device);
|
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device);
|
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device);
|
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
|
||||||
|
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
|
||||||
|
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
|
||||||
|
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
|
||||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
||||||
|
|
||||||
|
@ -61,9 +61,9 @@ namespace ctorch
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
inline torch::Tensor copy_from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device)
|
inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy)
|
||||||
{
|
{
|
||||||
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true);
|
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename NumType>
|
template <typename NumType>
|
||||||
|
@ -25,6 +25,23 @@ void set_seed(int seed)
|
|||||||
torch::manual_seed(seed);
|
torch::manual_seed(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||||
|
}
|
||||||
|
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||||
|
}
|
||||||
|
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||||
|
}
|
||||||
|
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||||
|
}
|
||||||
|
|
||||||
int get_dim(TorchTensorHandle tensor_handle)
|
int get_dim(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
return ctorch::cast(tensor_handle).dim();
|
return ctorch::cast(tensor_handle).dim();
|
||||||
@ -46,21 +63,21 @@ int get_device(TorchTensorHandle tensor_handle)
|
|||||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device)
|
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device)
|
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device)
|
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device)
|
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::copy_from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
}
|
}
|
||||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
@ -245,7 +262,7 @@ TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
|
|||||||
}
|
}
|
||||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i,j));
|
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i, j));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool requires_grad(TorchTensorHandle tensor_handle)
|
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||||
|
@ -3,17 +3,21 @@ package kscience.kmath.torch
|
|||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.ctorch.*
|
import kscience.kmath.ctorch.*
|
||||||
|
|
||||||
public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType> constructor(
|
||||||
internal val scope: DeferScope
|
internal val scope: DeferScope
|
||||||
) {
|
) {
|
||||||
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
|
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
|
||||||
|
|
||||||
public abstract fun copyFromArray(
|
public abstract fun copyFromArray(
|
||||||
array: PrimitiveArrayType,
|
array: PrimitiveArrayType,
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
device: TorchDevice = TorchDevice.TorchCPU
|
device: TorchDevice = TorchDevice.TorchCPU
|
||||||
): TorchTensor<T>
|
): TorchTensor<T>
|
||||||
|
|
||||||
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
|
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
|
||||||
|
|
||||||
|
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T>
|
||||||
|
public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
|
||||||
|
|
||||||
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
||||||
|
|
||||||
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
||||||
@ -56,13 +60,13 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
|||||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
|
public sealed class TorchTensorFieldAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) :
|
||||||
TorchTensorAlgebra<T, PrimitiveArrayType>(scope) {
|
TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) {
|
||||||
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
||||||
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleArray>(scope) {
|
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) {
|
||||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
||||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
@ -76,14 +80,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
|||||||
): TorchTensorReal =
|
): TorchTensorReal =
|
||||||
TorchTensorReal(
|
TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = copy_from_blob_double(
|
tensorHandle = from_blob_double(
|
||||||
array.toCValues(),
|
array.toCValues(),
|
||||||
shape.toCValues(),
|
shape.toCValues(),
|
||||||
shape.size,
|
shape.size,
|
||||||
device.toInt()
|
device.toInt(),
|
||||||
|
true
|
||||||
)!!
|
)!!
|
||||||
)
|
)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
|
||||||
|
TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = from_blob_double(
|
||||||
|
arrayBlob,
|
||||||
|
shape.toCValues(),
|
||||||
|
shape.size,
|
||||||
|
TorchDevice.TorchCPU.toInt(),
|
||||||
|
false
|
||||||
|
)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> {
|
||||||
|
require(this.device is TorchDevice.TorchCPU){
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_double(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
|
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||||
@ -116,5 +140,5 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
memScoped { TorchTensorRealAlgebra(this).block() }
|
memScoped { TorchTensorRealAlgebra(this).block() }
|
@ -1,5 +1,6 @@
|
|||||||
package kscience.kmath.torch
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlinx.cinterop.*
|
||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
|
|
||||||
internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
@ -20,6 +21,22 @@ class TestTorchTensor {
|
|||||||
@Test
|
@Test
|
||||||
fun testCopyFromArray() = testingCopyFromArray()
|
fun testCopyFromArray() = testingCopyFromArray()
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCopyLessDataTransferOnCPU() = memScoped {
|
||||||
|
val data = allocArray<DoubleVar>(1)
|
||||||
|
data[0] = 1.0
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
val tensor = fromBlob(data, intArrayOf(1))
|
||||||
|
assertEquals(tensor[intArrayOf(0)], 1.0)
|
||||||
|
data[0] = 2.0
|
||||||
|
assertEquals(tensor[intArrayOf(0)], 2.0)
|
||||||
|
val tensorData = tensor.getData()
|
||||||
|
tensorData[0] = 3.0
|
||||||
|
println(assertEquals(tensor[intArrayOf(0)], 3.0))
|
||||||
|
}
|
||||||
|
assertEquals(data[0], 3.0)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||||
val tensor = randNormal(intArrayOf(3))
|
val tensor = randNormal(intArrayOf(3))
|
||||||
|
Loading…
Reference in New Issue
Block a user