forked from kscience/kmath
Copyless data transfer
This commit is contained in:
parent
8967691b7d
commit
524b1d80d1
@ -18,10 +18,15 @@ extern "C"
|
||||
|
||||
void set_seed(int seed);
|
||||
|
||||
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device);
|
||||
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device);
|
||||
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device);
|
||||
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device);
|
||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||
|
||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
|
||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
||||
|
||||
|
@ -61,9 +61,9 @@ namespace ctorch
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor copy_from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device)
|
||||
inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy)
|
||||
{
|
||||
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true);
|
||||
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
|
||||
}
|
||||
|
||||
template <typename NumType>
|
||||
|
@ -25,6 +25,23 @@ void set_seed(int seed)
|
||||
torch::manual_seed(seed);
|
||||
}
|
||||
|
||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||
}
|
||||
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||
}
|
||||
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||
}
|
||||
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||
}
|
||||
|
||||
int get_dim(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).dim();
|
||||
@ -46,21 +63,21 @@ int get_device(TorchTensorHandle tensor_handle)
|
||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||
}
|
||||
|
||||
TorchTensorHandle copy_from_blob_double(double *data, int *shape, int dim, int device)
|
||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
||||
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle copy_from_blob_float(float *data, int *shape, int dim, int device)
|
||||
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
||||
return new torch::Tensor(ctorch::from_blob<float>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle copy_from_blob_long(long *data, int *shape, int dim, int device)
|
||||
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
||||
return new torch::Tensor(ctorch::from_blob<long>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle copy_from_blob_int(int *data, int *shape, int dim, int device)
|
||||
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy)
|
||||
{
|
||||
return new torch::Tensor(ctorch::copy_from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device)));
|
||||
return new torch::Tensor(ctorch::from_blob<int>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||
}
|
||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
|
@ -3,17 +3,21 @@ package kscience.kmath.torch
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.ctorch.*
|
||||
|
||||
public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
||||
public sealed class TorchTensorAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType> constructor(
|
||||
internal val scope: DeferScope
|
||||
) {
|
||||
internal abstract fun wrap(tensorHandle: COpaquePointer): TorchTensor<T>
|
||||
|
||||
public abstract fun copyFromArray(
|
||||
array: PrimitiveArrayType,
|
||||
shape: IntArray,
|
||||
device: TorchDevice = TorchDevice.TorchCPU
|
||||
): TorchTensor<T>
|
||||
|
||||
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
|
||||
|
||||
public abstract fun fromBlob(arrayBlob: CPointer<TVar>, shape: IntArray): TorchTensor<T>
|
||||
public abstract fun TorchTensor<T>.getData(): CPointer<TVar>
|
||||
|
||||
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
||||
|
||||
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
||||
@ -56,13 +60,13 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
|
||||
TorchTensorAlgebra<T, PrimitiveArrayType>(scope) {
|
||||
public sealed class TorchTensorFieldAlgebra<T, TVar: CPrimitiveVar, PrimitiveArrayType>(scope: DeferScope) :
|
||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType>(scope) {
|
||||
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
||||
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensor<T>
|
||||
}
|
||||
|
||||
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleArray>(scope) {
|
||||
public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra<Double, DoubleVar, DoubleArray>(scope) {
|
||||
override fun wrap(tensorHandle: COpaquePointer): TorchTensorReal =
|
||||
TorchTensorReal(scope = scope, tensorHandle = tensorHandle)
|
||||
|
||||
@ -76,14 +80,34 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
||||
): TorchTensorReal =
|
||||
TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = copy_from_blob_double(
|
||||
tensorHandle = from_blob_double(
|
||||
array.toCValues(),
|
||||
shape.toCValues(),
|
||||
shape.size,
|
||||
device.toInt()
|
||||
device.toInt(),
|
||||
true
|
||||
)!!
|
||||
)
|
||||
|
||||
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
|
||||
TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = from_blob_double(
|
||||
arrayBlob,
|
||||
shape.toCValues(),
|
||||
shape.size,
|
||||
TorchDevice.TorchCPU.toInt(),
|
||||
false
|
||||
)!!
|
||||
)
|
||||
|
||||
override fun TorchTensor<Double>.getData(): CPointer<DoubleVar> {
|
||||
require(this.device is TorchDevice.TorchCPU){
|
||||
"This tensor is not on available on CPU"
|
||||
}
|
||||
return get_data_double(this.tensorHandle)!!
|
||||
}
|
||||
|
||||
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||
@ -116,5 +140,5 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
||||
|
||||
}
|
||||
|
||||
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
memScoped { TorchTensorRealAlgebra(this).block() }
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kotlin.test.*
|
||||
|
||||
internal fun testingCopyFromArray(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||
@ -20,6 +21,22 @@ class TestTorchTensor {
|
||||
@Test
|
||||
fun testCopyFromArray() = testingCopyFromArray()
|
||||
|
||||
@Test
|
||||
fun testCopyLessDataTransferOnCPU() = memScoped {
|
||||
val data = allocArray<DoubleVar>(1)
|
||||
data[0] = 1.0
|
||||
TorchTensorRealAlgebra {
|
||||
val tensor = fromBlob(data, intArrayOf(1))
|
||||
assertEquals(tensor[intArrayOf(0)], 1.0)
|
||||
data[0] = 2.0
|
||||
assertEquals(tensor[intArrayOf(0)], 2.0)
|
||||
val tensorData = tensor.getData()
|
||||
tensorData[0] = 3.0
|
||||
println(assertEquals(tensor[intArrayOf(0)], 3.0))
|
||||
}
|
||||
assertEquals(data[0], 3.0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||
val tensor = randNormal(intArrayOf(3))
|
||||
|
Loading…
Reference in New Issue
Block a user