Float, Long and Int support
This commit is contained in:
parent
80f28dbcd5
commit
ef570254e6
@ -20,22 +20,21 @@ extern "C"
|
|||||||
|
|
||||||
TorchTensorHandle empty_tensor();
|
TorchTensorHandle empty_tensor();
|
||||||
|
|
||||||
double *get_data_double(TorchTensorHandle tensor_handle);
|
|
||||||
float *get_data_float(TorchTensorHandle tensor_handle);
|
|
||||||
long *get_data_long(TorchTensorHandle tensor_handle);
|
|
||||||
int *get_data_int(TorchTensorHandle tensor_handle);
|
|
||||||
|
|
||||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
|
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy);
|
||||||
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
|
TorchTensorHandle from_blob_float(float *data, int *shape, int dim, int device, bool copy);
|
||||||
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
|
TorchTensorHandle from_blob_long(long *data, int *shape, int dim, int device, bool copy);
|
||||||
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
|
TorchTensorHandle from_blob_int(int *data, int *shape, int dim, int device, bool copy);
|
||||||
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
TorchTensorHandle copy_tensor(TorchTensorHandle tensor_handle);
|
||||||
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device);
|
||||||
|
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle);
|
||||||
|
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle);
|
||||||
|
|
||||||
double get_item_double(TorchTensorHandle tensor_handle);
|
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
||||||
float get_item_float(TorchTensorHandle tensor_handle);
|
void dispose_char(char *ptr);
|
||||||
long get_item_long(TorchTensorHandle tensor_handle);
|
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||||
int get_item_int(TorchTensorHandle tensor_handle);
|
|
||||||
|
|
||||||
int get_dim(TorchTensorHandle tensor_handle);
|
int get_dim(TorchTensorHandle tensor_handle);
|
||||||
int get_numel(TorchTensorHandle tensor_handle);
|
int get_numel(TorchTensorHandle tensor_handle);
|
||||||
@ -43,9 +42,15 @@ extern "C"
|
|||||||
int get_stride_at(TorchTensorHandle tensor_handle, int d);
|
int get_stride_at(TorchTensorHandle tensor_handle, int d);
|
||||||
int get_device(TorchTensorHandle tensor_handle);
|
int get_device(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
char *tensor_to_string(TorchTensorHandle tensor_handle);
|
double *get_data_double(TorchTensorHandle tensor_handle);
|
||||||
void dispose_char(char *ptr);
|
float *get_data_float(TorchTensorHandle tensor_handle);
|
||||||
void dispose_tensor(TorchTensorHandle tensor_handle);
|
long *get_data_long(TorchTensorHandle tensor_handle);
|
||||||
|
int *get_data_int(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
double get_item_double(TorchTensorHandle tensor_handle);
|
||||||
|
float get_item_float(TorchTensorHandle tensor_handle);
|
||||||
|
long get_item_long(TorchTensorHandle tensor_handle);
|
||||||
|
int get_item_int(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
double get_double(TorchTensorHandle tensor_handle, int *index);
|
double get_double(TorchTensorHandle tensor_handle, int *index);
|
||||||
float get_float(TorchTensorHandle tensor_handle, int *index);
|
float get_float(TorchTensorHandle tensor_handle, int *index);
|
||||||
@ -61,15 +66,14 @@ extern "C"
|
|||||||
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
|
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
|
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
|
||||||
|
|
||||||
|
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device);
|
||||||
|
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device);
|
||||||
|
|
||||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
|
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
|
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
|
||||||
|
|
||||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
|
||||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
|
||||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
|
||||||
|
|
||||||
TorchTensorHandle times_double(double value, TorchTensorHandle other);
|
TorchTensorHandle times_double(double value, TorchTensorHandle other);
|
||||||
TorchTensorHandle times_float(float value, TorchTensorHandle other);
|
TorchTensorHandle times_float(float value, TorchTensorHandle other);
|
||||||
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
||||||
@ -115,10 +119,9 @@ extern "C"
|
|||||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||||
void sum_tensor_assign(TorchTensorHandle tensor_handle);
|
void sum_tensor_assign(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
|
||||||
|
|
||||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
|
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2);
|
||||||
|
|
||||||
@ -132,6 +135,11 @@ extern "C"
|
|||||||
TorchTensorHandle V_handle,
|
TorchTensorHandle V_handle,
|
||||||
bool eigenvectors);
|
bool eigenvectors);
|
||||||
|
|
||||||
|
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||||
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||||
|
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -92,6 +92,12 @@ namespace ctorch
|
|||||||
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor randint(Dtype low, Dtype high, std::vector<int64_t> shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::randint(low, high, shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
|
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
|
||||||
{
|
{
|
||||||
|
@ -30,44 +30,6 @@ TorchTensorHandle empty_tensor()
|
|||||||
return new torch::Tensor;
|
return new torch::Tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
double *get_data_double(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).data_ptr<double>();
|
|
||||||
}
|
|
||||||
float *get_data_float(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).data_ptr<float>();
|
|
||||||
}
|
|
||||||
long *get_data_long(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).data_ptr<long>();
|
|
||||||
}
|
|
||||||
int *get_data_int(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).data_ptr<int>();
|
|
||||||
}
|
|
||||||
|
|
||||||
int get_dim(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).dim();
|
|
||||||
}
|
|
||||||
int get_numel(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).numel();
|
|
||||||
}
|
|
||||||
int get_shape_at(TorchTensorHandle tensor_handle, int d)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).size(d);
|
|
||||||
}
|
|
||||||
int get_stride_at(TorchTensorHandle tensor_handle, int d)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).stride(d);
|
|
||||||
}
|
|
||||||
int get_device(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
|
||||||
}
|
|
||||||
|
|
||||||
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
|
TorchTensorHandle from_blob_double(double *data, int *shape, int dim, int device, bool copy)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
return new torch::Tensor(ctorch::from_blob<double>(data, ctorch::to_vec_int(shape, dim), ctorch::int_to_device(device), copy));
|
||||||
@ -92,6 +54,26 @@ TorchTensorHandle copy_to_device(TorchTensorHandle tensor_handle, int device)
|
|||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::int_to_device(device), false, true));
|
||||||
}
|
}
|
||||||
|
TorchTensorHandle copy_to_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<double>(), false, true));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<float>(), false, true));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<long>(), false, true));
|
||||||
|
}
|
||||||
|
TorchTensorHandle copy_to_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).to(ctorch::dtype<int>(), false, true));
|
||||||
|
}
|
||||||
|
void swap_tensors(TorchTensorHandle lhs_handle, TorchTensorHandle rhs_handle)
|
||||||
|
{
|
||||||
|
std::swap(ctorch::cast(lhs_handle), ctorch::cast(rhs_handle));
|
||||||
|
}
|
||||||
|
|
||||||
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
char *tensor_to_string(TorchTensorHandle tensor_handle)
|
||||||
{
|
{
|
||||||
@ -111,6 +93,61 @@ void dispose_tensor(TorchTensorHandle tensor_handle)
|
|||||||
delete static_cast<torch::Tensor *>(tensor_handle);
|
delete static_cast<torch::Tensor *>(tensor_handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int get_dim(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).dim();
|
||||||
|
}
|
||||||
|
int get_numel(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).numel();
|
||||||
|
}
|
||||||
|
int get_shape_at(TorchTensorHandle tensor_handle, int d)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).size(d);
|
||||||
|
}
|
||||||
|
int get_stride_at(TorchTensorHandle tensor_handle, int d)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).stride(d);
|
||||||
|
}
|
||||||
|
int get_device(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::device_to_int(ctorch::cast(tensor_handle));
|
||||||
|
}
|
||||||
|
|
||||||
|
double *get_data_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<double>();
|
||||||
|
}
|
||||||
|
float *get_data_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<float>();
|
||||||
|
}
|
||||||
|
long *get_data_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<long>();
|
||||||
|
}
|
||||||
|
int *get_data_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).data_ptr<int>();
|
||||||
|
}
|
||||||
|
|
||||||
|
double get_item_double(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<double>();
|
||||||
|
}
|
||||||
|
float get_item_float(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<float>();
|
||||||
|
}
|
||||||
|
long get_item_long(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<long>();
|
||||||
|
}
|
||||||
|
int get_item_int(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).item<int>();
|
||||||
|
}
|
||||||
|
|
||||||
double get_double(TorchTensorHandle tensor_handle, int *index)
|
double get_double(TorchTensorHandle tensor_handle, int *index)
|
||||||
{
|
{
|
||||||
return ctorch::get<double>(tensor_handle, index);
|
return ctorch::get<double>(tensor_handle, index);
|
||||||
@ -144,23 +181,6 @@ void set_int(TorchTensorHandle tensor_handle, int *index, int value)
|
|||||||
ctorch::set<int>(tensor_handle, index, value);
|
ctorch::set<int>(tensor_handle, index, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
double get_item_double(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).item<double>();
|
|
||||||
}
|
|
||||||
float get_item_float(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).item<float>();
|
|
||||||
}
|
|
||||||
long get_item_long(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).item<long>();
|
|
||||||
}
|
|
||||||
int get_item_int(TorchTensorHandle tensor_handle)
|
|
||||||
{
|
|
||||||
return ctorch::cast(tensor_handle).item<int>();
|
|
||||||
}
|
|
||||||
|
|
||||||
TorchTensorHandle randn_double(int *shape, int shape_size, int device)
|
TorchTensorHandle randn_double(int *shape, int shape_size, int device)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::randn<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::randn<double>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
@ -178,6 +198,15 @@ TorchTensorHandle rand_float(int *shape, int shape_size, int device)
|
|||||||
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle randint_long(long low, long high, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randint<long>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle randint_int(int low, int high, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::randint<int>(low, high, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
|
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
@ -195,19 +224,6 @@ TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
|
|||||||
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
|
||||||
{
|
|
||||||
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
|
||||||
}
|
|
||||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
|
||||||
{
|
|
||||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
|
||||||
}
|
|
||||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
|
||||||
{
|
|
||||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
TorchTensorHandle plus_double(double value, TorchTensorHandle other)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(other) + value);
|
return new torch::Tensor(ctorch::cast(other) + value);
|
||||||
@ -356,21 +372,17 @@ void sum_tensor_assign(TorchTensorHandle tensor_handle)
|
|||||||
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
ctorch::cast(tensor_handle) = ctorch::cast(tensor_handle).sum();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool requires_grad(TorchTensorHandle tensor_handle)
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
{
|
{
|
||||||
return ctorch::cast(tensor_handle).requires_grad();
|
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||||
}
|
}
|
||||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
{
|
{
|
||||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
}
|
}
|
||||||
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
}
|
|
||||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
|
||||||
{
|
|
||||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
|
TorchTensorHandle diag_embed(TorchTensorHandle diags_handle, int offset, int dim1, int dim2)
|
||||||
@ -398,3 +410,20 @@ void symeig_tensor(TorchTensorHandle tensor_handle,
|
|||||||
ctorch::cast(S_handle) = S;
|
ctorch::cast(S_handle) = S;
|
||||||
ctorch::cast(V_handle) = V;
|
ctorch::cast(V_handle) = V;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).requires_grad();
|
||||||
|
}
|
||||||
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||||
|
}
|
||||||
|
TorchTensorHandle detach_from_graph(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).detach());
|
||||||
|
}
|
||||||
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||||
|
}
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
class BenchmarkMatMultFloatGPU {
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMult20() =
|
||||||
|
benchmarkingMatMultFloat(20, 10, 100000,
|
||||||
|
device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMult200() =
|
||||||
|
benchmarkingMatMultFloat(200, 10, 10000,
|
||||||
|
device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMult2000() =
|
||||||
|
benchmarkingMatMultFloat(2000, 10, 1000,
|
||||||
|
device = TorchDevice.TorchCUDA(0))
|
||||||
|
}
|
@ -29,5 +29,4 @@ class TestTorchTensorAlgebraGPU {
|
|||||||
fun testBatchedSymEig() =
|
fun testBatchedSymEig() =
|
||||||
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
|
testingBatchedSymEig(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,11 @@ package kscience.kmath.torch
|
|||||||
|
|
||||||
|
|
||||||
public sealed class TorchDevice {
|
public sealed class TorchDevice {
|
||||||
public object TorchCPU: TorchDevice()
|
public object TorchCPU: TorchDevice() {
|
||||||
|
override fun toString(): String {
|
||||||
|
return "TorchCPU"
|
||||||
|
}
|
||||||
|
}
|
||||||
public data class TorchCUDA(val index: Int): TorchDevice()
|
public data class TorchCUDA(val index: Int): TorchDevice()
|
||||||
public fun toInt(): Int {
|
public fun toInt(): Int {
|
||||||
when(this) {
|
when(this) {
|
||||||
|
@ -53,6 +53,23 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
get() = requires_grad(tensorHandle)
|
get() = requires_grad(tensorHandle)
|
||||||
set(value) = requires_grad_(tensorHandle, value)
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
|
|
||||||
|
public fun copyToDouble(): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_double(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_float(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
public fun copyToLong(): TorchTensorLong = TorchTensorLong(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_long(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
public fun copyToInt(): TorchTensorInt = TorchTensorInt(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = copy_to_int(this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorReal internal constructor(
|
public class TorchTensorReal internal constructor(
|
||||||
@ -66,6 +83,39 @@ public class TorchTensorReal internal constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class TorchTensorFloat internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensor<Float>(scope, tensorHandle) {
|
||||||
|
override fun item(): Float = get_item_float(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Float = get_float(tensorHandle, index.toCValues())
|
||||||
|
override fun set(index: IntArray, value: Float) {
|
||||||
|
set_float(tensorHandle, index.toCValues(), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorLong internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensor<Long>(scope, tensorHandle) {
|
||||||
|
override fun item(): Long = get_item_long(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Long = get_long(tensorHandle, index.toCValues())
|
||||||
|
override fun set(index: IntArray, value: Long) {
|
||||||
|
set_long(tensorHandle, index.toCValues(), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorInt internal constructor(
|
||||||
|
scope: DeferScope,
|
||||||
|
tensorHandle: COpaquePointer
|
||||||
|
) : TorchTensor<Int>(scope, tensorHandle) {
|
||||||
|
override fun item(): Int = get_item_int(tensorHandle)
|
||||||
|
override fun get(index: IntArray): Int = get_int(tensorHandle, index.toCValues())
|
||||||
|
override fun set(index: IntArray, value: Int) {
|
||||||
|
set_int(tensorHandle, index.toCValues(), value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
private inline fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||||
val res = IntArray(nDim)
|
val res = IntArray(nDim)
|
||||||
|
@ -42,13 +42,6 @@ public sealed class TorchTensorAlgebra<
|
|||||||
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
times_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
|
|
||||||
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
|
||||||
|
|
||||||
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
|
||||||
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
|
||||||
}
|
|
||||||
|
|
||||||
public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType =
|
public infix fun TorchTensorType.dot(other: TorchTensorType): TorchTensorType =
|
||||||
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
@ -94,10 +87,47 @@ public sealed class TorchTensorAlgebra<
|
|||||||
sum_tensor_assign(tensorHandle)
|
sum_tensor_assign(tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun diagEmbed(diagonalEntries: TorchTensorType,
|
public fun diagEmbed(
|
||||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1): TorchTensorType =
|
diagonalEntries: TorchTensorType,
|
||||||
|
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||||
|
): TorchTensorType =
|
||||||
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
wrap(diag_embed(diagonalEntries.tensorHandle, offset, dim1, dim2)!!)
|
||||||
|
|
||||||
|
public fun TorchTensorType.copy(): TorchTensorType =
|
||||||
|
wrap(copy_tensor(this.tensorHandle)!!)
|
||||||
|
|
||||||
|
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
|
||||||
|
wrap(copy_to_device(this.tensorHandle, device.toInt())!!)
|
||||||
|
|
||||||
|
public infix fun TorchTensorType.swap(otherTensor: TorchTensorType): Unit {
|
||||||
|
swap_tensors(this.tensorHandle, otherTensor.tensorHandle)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
||||||
|
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
||||||
|
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
|
||||||
|
|
||||||
|
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
||||||
|
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
||||||
|
|
||||||
|
public operator fun TorchTensorType.div(other: TorchTensorType): TorchTensorType =
|
||||||
|
wrap(div_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
public operator fun TorchTensorType.divAssign(other: TorchTensorType): Unit {
|
||||||
|
div_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
||||||
|
public fun TorchTensorType.expAssign(): Unit {
|
||||||
|
exp_tensor_assign(tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
||||||
|
public fun TorchTensorType.logAssign(): Unit {
|
||||||
|
log_tensor_assign(tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType> {
|
||||||
val U = empty_tensor()!!
|
val U = empty_tensor()!!
|
||||||
val V = empty_tensor()!!
|
val V = empty_tensor()!!
|
||||||
@ -116,33 +146,17 @@ public sealed class TorchTensorAlgebra<
|
|||||||
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
public infix fun TorchTensorType.grad(variable: TorchTensorType): TorchTensorType =
|
||||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||||
|
|
||||||
public fun TorchTensorType.copy(): TorchTensorType =
|
|
||||||
wrap(tensorHandle = copy_tensor(this.tensorHandle)!!)
|
|
||||||
|
|
||||||
public fun TorchTensorType.copyToDevice(device: TorchDevice): TorchTensorType =
|
|
||||||
wrap(tensorHandle = copy_to_device(this.tensorHandle, device.toInt())!!)
|
|
||||||
|
|
||||||
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
public fun TorchTensorType.detachFromGraph(): TorchTensorType =
|
||||||
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
wrap(tensorHandle = detach_from_graph(this.tensorHandle)!!)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorFieldAlgebra<T, TVar : CPrimitiveVar,
|
public sealed class TorchTensorRingAlgebra<T, TVar : CPrimitiveVar,
|
||||||
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
PrimitiveArrayType, TorchTensorType : TorchTensor<T>>(scope: DeferScope) :
|
||||||
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
|
TorchTensorAlgebra<T, TVar, PrimitiveArrayType, TorchTensorType>(scope) {
|
||||||
public abstract fun randNormal(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
public abstract fun randIntegral(
|
||||||
public abstract fun randUniform(shape: IntArray, device: TorchDevice = TorchDevice.TorchCPU): TorchTensorType
|
low: T, high: T, shape: IntArray,
|
||||||
|
device: TorchDevice = TorchDevice.TorchCPU
|
||||||
public fun TorchTensorType.exp(): TorchTensorType = wrap(exp_tensor(tensorHandle)!!)
|
): TorchTensorType
|
||||||
public fun TorchTensorType.expAssign(): Unit {
|
|
||||||
exp_tensor_assign(tensorHandle)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun TorchTensorType.log(): TorchTensorType = wrap(log_tensor(tensorHandle)!!)
|
|
||||||
public fun TorchTensorType.logAssign(): Unit {
|
|
||||||
log_tensor_assign(tensorHandle)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
public class TorchTensorRealAlgebra(scope: DeferScope) :
|
||||||
@ -153,33 +167,11 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
|||||||
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
override fun TorchTensorReal.copyToArray(): DoubleArray =
|
||||||
this.elements().map { it.second }.toList().toDoubleArray()
|
this.elements().map { it.second }.toList().toDoubleArray()
|
||||||
|
|
||||||
override fun copyFromArray(
|
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||||
array: DoubleArray,
|
wrap(from_blob_double(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
shape: IntArray,
|
|
||||||
device: TorchDevice
|
|
||||||
): TorchTensorReal =
|
|
||||||
TorchTensorReal(
|
|
||||||
scope = scope,
|
|
||||||
tensorHandle = from_blob_double(
|
|
||||||
array.toCValues(),
|
|
||||||
shape.toCValues(),
|
|
||||||
shape.size,
|
|
||||||
device.toInt(),
|
|
||||||
true
|
|
||||||
)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
|
override fun fromBlob(arrayBlob: CPointer<DoubleVar>, shape: IntArray): TorchTensorReal =
|
||||||
TorchTensorReal(
|
wrap(from_blob_double(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||||
scope = scope,
|
|
||||||
tensorHandle = from_blob_double(
|
|
||||||
arrayBlob,
|
|
||||||
shape.toCValues(),
|
|
||||||
shape.size,
|
|
||||||
TorchDevice.TorchCPU.toInt(),
|
|
||||||
false
|
|
||||||
)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
|
override fun TorchTensorReal.getData(): CPointer<DoubleVar> {
|
||||||
require(this.device is TorchDevice.TorchCPU) {
|
require(this.device is TorchDevice.TorchCPU) {
|
||||||
@ -188,65 +180,231 @@ public class TorchTensorRealAlgebra(scope: DeferScope) :
|
|||||||
return get_data_double(this.tensorHandle)!!
|
return get_data_double(this.tensorHandle)!!
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(randn_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
tensorHandle = randn_double(shape.toCValues(), shape.size, device.toInt())!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(rand_double(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
override operator fun Double.plus(other: TorchTensorReal): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(plus_double(this, other.tensorHandle)!!)
|
||||||
tensorHandle = plus_double(this, other.tensorHandle)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun TorchTensorReal.plus(value: Double): TorchTensorReal = TorchTensorReal(
|
override fun TorchTensorReal.plus(value: Double): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(plus_double(value, this.tensorHandle)!!)
|
||||||
tensorHandle = plus_double(value, this.tensorHandle)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun TorchTensorReal.plusAssign(value: Double): Unit {
|
override fun TorchTensorReal.plusAssign(value: Double): Unit {
|
||||||
plus_double_assign(value, this.tensorHandle)
|
plus_double_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
override operator fun Double.minus(other: TorchTensorReal): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(plus_double(-this, other.tensorHandle)!!)
|
||||||
tensorHandle = plus_double(-this, other.tensorHandle)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun TorchTensorReal.minus(value: Double): TorchTensorReal = TorchTensorReal(
|
override fun TorchTensorReal.minus(value: Double): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(plus_double(-value, this.tensorHandle)!!)
|
||||||
tensorHandle = plus_double(-value, this.tensorHandle)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun TorchTensorReal.minusAssign(value: Double): Unit {
|
override fun TorchTensorReal.minusAssign(value: Double): Unit {
|
||||||
plus_double_assign(-value, this.tensorHandle)
|
plus_double_assign(-value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal = TorchTensorReal(
|
override operator fun Double.times(other: TorchTensorReal): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(times_double(this, other.tensorHandle)!!)
|
||||||
tensorHandle = times_double(this, other.tensorHandle)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun TorchTensorReal.times(value: Double): TorchTensorReal = TorchTensorReal(
|
override fun TorchTensorReal.times(value: Double): TorchTensorReal =
|
||||||
scope = scope,
|
wrap(times_double(value, this.tensorHandle)!!)
|
||||||
tensorHandle = times_double(value, this.tensorHandle)!!
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun TorchTensorReal.timesAssign(value: Double): Unit {
|
override fun TorchTensorReal.timesAssign(value: Double): Unit {
|
||||||
times_double_assign(value, this.tensorHandle)
|
times_double_assign(value, this.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal =
|
||||||
|
wrap(full_double(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
}
|
||||||
|
|
||||||
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
public class TorchTensorFloatAlgebra(scope: DeferScope) :
|
||||||
scope = scope,
|
TorchTensorFieldAlgebra<Float, FloatVar, FloatArray, TorchTensorFloat>(scope) {
|
||||||
tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!!
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorFloat =
|
||||||
)
|
TorchTensorFloat(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.copyToArray(): FloatArray =
|
||||||
|
this.elements().map { it.second }.toList().toFloatArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: FloatArray, shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||||
|
wrap(from_blob_float(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<FloatVar>, shape: IntArray): TorchTensorFloat =
|
||||||
|
wrap(from_blob_float(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.getData(): CPointer<FloatVar> {
|
||||||
|
require(this.device is TorchDevice.TorchCPU) {
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_float(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun randNormal(shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||||
|
wrap(randn_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override fun randUniform(shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||||
|
wrap(rand_float(shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override operator fun Float.plus(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(plus_float(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.plus(value: Float): TorchTensorFloat =
|
||||||
|
wrap(plus_float(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.plusAssign(value: Float): Unit {
|
||||||
|
plus_float_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Float.minus(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(plus_float(-this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.minus(value: Float): TorchTensorFloat =
|
||||||
|
wrap(plus_float(-value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.minusAssign(value: Float): Unit {
|
||||||
|
plus_float_assign(-value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Float.times(other: TorchTensorFloat): TorchTensorFloat =
|
||||||
|
wrap(times_float(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.times(value: Float): TorchTensorFloat =
|
||||||
|
wrap(times_float(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorFloat.timesAssign(value: Float): Unit {
|
||||||
|
times_float_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun full(value: Float, shape: IntArray, device: TorchDevice): TorchTensorFloat =
|
||||||
|
wrap(full_float(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorLongAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorRingAlgebra<Long, LongVar, LongArray, TorchTensorLong>(scope) {
|
||||||
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorLong =
|
||||||
|
TorchTensorLong(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.copyToArray(): LongArray =
|
||||||
|
this.elements().map { it.second }.toList().toLongArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: LongArray, shape: IntArray, device: TorchDevice): TorchTensorLong =
|
||||||
|
wrap(from_blob_long(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<LongVar>, shape: IntArray): TorchTensorLong =
|
||||||
|
wrap(from_blob_long(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.getData(): CPointer<LongVar> {
|
||||||
|
require(this.device is TorchDevice.TorchCPU) {
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_long(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun randIntegral(low: Long, high: Long, shape: IntArray, device: TorchDevice): TorchTensorLong =
|
||||||
|
wrap(randint_long(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override operator fun Long.plus(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(plus_long(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.plus(value: Long): TorchTensorLong =
|
||||||
|
wrap(plus_long(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.plusAssign(value: Long): Unit {
|
||||||
|
plus_long_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Long.minus(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(plus_long(-this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.minus(value: Long): TorchTensorLong =
|
||||||
|
wrap(plus_long(-value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.minusAssign(value: Long): Unit {
|
||||||
|
plus_long_assign(-value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Long.times(other: TorchTensorLong): TorchTensorLong =
|
||||||
|
wrap(times_long(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.times(value: Long): TorchTensorLong =
|
||||||
|
wrap(times_long(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorLong.timesAssign(value: Long): Unit {
|
||||||
|
times_long_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun full(value: Long, shape: IntArray, device: TorchDevice): TorchTensorLong =
|
||||||
|
wrap(full_long(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TorchTensorIntAlgebra(scope: DeferScope) :
|
||||||
|
TorchTensorRingAlgebra<Int, IntVar, IntArray, TorchTensorInt>(scope) {
|
||||||
|
override fun wrap(tensorHandle: COpaquePointer): TorchTensorInt =
|
||||||
|
TorchTensorInt(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.copyToArray(): IntArray =
|
||||||
|
this.elements().map { it.second }.toList().toIntArray()
|
||||||
|
|
||||||
|
override fun copyFromArray(array: IntArray, shape: IntArray, device: TorchDevice): TorchTensorInt =
|
||||||
|
wrap(from_blob_int(array.toCValues(), shape.toCValues(), shape.size, device.toInt(), true)!!)
|
||||||
|
|
||||||
|
override fun fromBlob(arrayBlob: CPointer<IntVar>, shape: IntArray): TorchTensorInt =
|
||||||
|
wrap(from_blob_int(arrayBlob, shape.toCValues(), shape.size, TorchDevice.TorchCPU.toInt(), false)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.getData(): CPointer<IntVar> {
|
||||||
|
require(this.device is TorchDevice.TorchCPU) {
|
||||||
|
"This tensor is not on available on CPU"
|
||||||
|
}
|
||||||
|
return get_data_int(this.tensorHandle)!!
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun randIntegral(low: Int, high: Int, shape: IntArray, device: TorchDevice): TorchTensorInt =
|
||||||
|
wrap(randint_int(low, high, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
|
|
||||||
|
override operator fun Int.plus(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(plus_int(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.plus(value: Int): TorchTensorInt =
|
||||||
|
wrap(plus_int(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.plusAssign(value: Int): Unit {
|
||||||
|
plus_int_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Int.minus(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(plus_int(-this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.minus(value: Int): TorchTensorInt =
|
||||||
|
wrap(plus_int(-value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.minusAssign(value: Int): Unit {
|
||||||
|
plus_int_assign(-value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun Int.times(other: TorchTensorInt): TorchTensorInt =
|
||||||
|
wrap(times_int(this, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.times(value: Int): TorchTensorInt =
|
||||||
|
wrap(times_int(value, this.tensorHandle)!!)
|
||||||
|
|
||||||
|
override fun TorchTensorInt.timesAssign(value: Int): Unit {
|
||||||
|
times_int_assign(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun full(value: Int, shape: IntArray, device: TorchDevice): TorchTensorInt =
|
||||||
|
wrap(full_int(value, shape.toCValues(), shape.size, device.toInt())!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
public inline fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
memScoped { TorchTensorRealAlgebra(this).block() }
|
memScoped { TorchTensorRealAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorFloatAlgebra(block: TorchTensorFloatAlgebra.() -> R): R =
|
||||||
|
memScoped { TorchTensorFloatAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorLongAlgebra(block: TorchTensorLongAlgebra.() -> R): R =
|
||||||
|
memScoped { TorchTensorLongAlgebra(this).block() }
|
||||||
|
|
||||||
|
public inline fun <R> TorchTensorIntAlgebra(block: TorchTensorIntAlgebra.() -> R): R =
|
||||||
|
memScoped { TorchTensorIntAlgebra(this).block() }
|
@ -3,14 +3,14 @@ package kscience.kmath.torch
|
|||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.time.measureTime
|
import kotlin.time.measureTime
|
||||||
|
|
||||||
internal fun benchmarkingDoubleMatrixMultiplication(
|
internal fun benchmarkingMatMultDouble(
|
||||||
scale: Int,
|
scale: Int,
|
||||||
numWarmUp: Int,
|
numWarmUp: Int,
|
||||||
numIter: Int,
|
numIter: Int,
|
||||||
device: TorchDevice = TorchDevice.TorchCPU
|
device: TorchDevice = TorchDevice.TorchCPU
|
||||||
): Unit {
|
): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
println("Benchmarking $scale x $scale matrices over Double's: ")
|
println("Benchmarking $scale x $scale matrices over Double's on $device: ")
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
@ -20,15 +20,18 @@ internal fun benchmarkingDoubleMatrixMultiplication(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class BenchmarkMatrixMultiplicationDouble {
|
internal class BenchmarkMatMultDouble {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun benchmarkMatrixMultiplication20() = benchmarkingDoubleMatrixMultiplication(20, 10, 100000)
|
fun benchmarkMatMult20() =
|
||||||
|
benchmarkingMatMultDouble(20, 10, 100000)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun benchmarkMatrixMultiplication200() = benchmarkingDoubleMatrixMultiplication(200, 10, 10000)
|
fun benchmarkMatMult200() =
|
||||||
|
benchmarkingMatMultDouble(200, 10, 10000)
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun benchmarkMatrixMultiplication2000() = benchmarkingDoubleMatrixMultiplication(2000, 3, 20)
|
fun benchmarkMatMult2000() =
|
||||||
|
benchmarkingMatMultDouble(2000, 3, 20)
|
||||||
|
|
||||||
}
|
}
|
@ -0,0 +1,36 @@
|
|||||||
|
package kscience.kmath.torch
|
||||||
|
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.time.measureTime
|
||||||
|
|
||||||
|
internal fun benchmarkingMatMultFloat(
|
||||||
|
scale: Int,
|
||||||
|
numWarmUp: Int,
|
||||||
|
numIter: Int,
|
||||||
|
device: TorchDevice = TorchDevice.TorchCPU
|
||||||
|
): Unit {
|
||||||
|
TorchTensorFloatAlgebra {
|
||||||
|
println("Benchmarking $scale x $scale matrices over Float's on $device: ")
|
||||||
|
setSeed(SEED)
|
||||||
|
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
|
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||||
|
repeat(numWarmUp) { lhs dotAssign rhs }
|
||||||
|
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||||
|
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class BenchmarkMatMultFloat {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMult20() =
|
||||||
|
benchmarkingMatMultFloat(20, 10, 100000)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMult200() =
|
||||||
|
benchmarkingMatMultFloat(200, 10, 10000)
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun benchmarkMatMult2000() =
|
||||||
|
benchmarkingMatMultFloat(2000, 3, 20)
|
||||||
|
}
|
@ -44,7 +44,6 @@ internal fun testingBatchedAutoGrad(bath: IntArray,
|
|||||||
|
|
||||||
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
val error = (gradientAtX - expectedGradientAtX).abs().sum().value()
|
||||||
assertTrue(error < TOLERANCE)
|
assertTrue(error < TOLERANCE)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,4 +49,15 @@ class TestTorchTensor {
|
|||||||
val detachedTensor = tensor.detachFromGraph()
|
val detachedTensor = tensor.detachFromGraph()
|
||||||
assertTrue(!detachedTensor.requiresGrad)
|
assertTrue(!detachedTensor.requiresGrad)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTypeMoving() = TorchTensorFloatAlgebra {
|
||||||
|
val tensorInt = copyFromArray(floatArrayOf(1f,2f,3f), intArrayOf(3)).copyToInt()
|
||||||
|
TorchTensorIntAlgebra {
|
||||||
|
val temporalTensor = copyFromArray(intArrayOf(4,5,6),intArrayOf(3))
|
||||||
|
tensorInt swap temporalTensor
|
||||||
|
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1,2,3))
|
||||||
|
}
|
||||||
|
assertTrue(tensorInt.copyToFloat().copyToArray() contentEquals floatArrayOf(4f,5f,6f))
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user