2020-12-30 01:42:33 +03:00
|
|
|
#include <torch/torch.h>
|
|
|
|
|
|
|
|
#include "ctorch.h"
|
|
|
|
|
|
|
|
namespace ctorch
|
|
|
|
{
|
|
|
|
template <typename Dtype>
|
|
|
|
inline c10::ScalarType dtype()
|
|
|
|
{
|
|
|
|
return torch::kFloat64;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
inline c10::ScalarType dtype<float>()
|
|
|
|
{
|
|
|
|
return torch::kFloat32;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
inline c10::ScalarType dtype<long>()
|
|
|
|
{
|
|
|
|
return torch::kInt64;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
inline c10::ScalarType dtype<int>()
|
|
|
|
{
|
|
|
|
return torch::kInt32;
|
|
|
|
}
|
|
|
|
|
2021-01-06 16:20:48 +03:00
|
|
|
inline torch::Tensor &cast(const TorchTensorHandle &tensor_handle)
|
2020-12-30 01:42:33 +03:00
|
|
|
{
|
|
|
|
return *static_cast<torch::Tensor *>(tensor_handle);
|
|
|
|
}
|
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
inline int device_to_int(const torch::Tensor &tensor)
|
2021-01-08 11:27:39 +03:00
|
|
|
{
|
2021-01-09 20:13:38 +03:00
|
|
|
return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
|
2021-01-08 11:27:39 +03:00
|
|
|
}
|
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
inline torch::Device int_to_device(int device_int)
|
2020-12-30 01:42:33 +03:00
|
|
|
{
|
2021-01-09 20:13:38 +03:00
|
|
|
return (device_int == 0) ? torch::kCPU : torch::Device(torch::kCUDA, device_int - 1);
|
2020-12-30 01:42:33 +03:00
|
|
|
}
|
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
inline std::vector<int64_t> to_vec_int(int *arr, int arr_size)
|
2020-12-30 01:42:33 +03:00
|
|
|
{
|
2021-01-09 20:13:38 +03:00
|
|
|
auto vec = std::vector<int64_t>(arr_size);
|
|
|
|
vec.assign(arr, arr + arr_size);
|
|
|
|
return vec;
|
2020-12-30 01:42:33 +03:00
|
|
|
}
|
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
inline std::vector<at::indexing::TensorIndex> to_index(int *arr, int arr_size)
|
2021-01-06 16:20:48 +03:00
|
|
|
{
|
|
|
|
std::vector<at::indexing::TensorIndex> index;
|
2021-01-09 20:13:38 +03:00
|
|
|
for (int i = 0; i < arr_size; i++)
|
2021-01-06 16:20:48 +03:00
|
|
|
{
|
2021-01-09 20:13:38 +03:00
|
|
|
index.emplace_back(arr[i]);
|
2021-01-06 16:20:48 +03:00
|
|
|
}
|
|
|
|
return index;
|
|
|
|
}
|
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
template <typename Dtype>
|
2021-01-12 22:49:16 +03:00
|
|
|
inline torch::Tensor from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device, bool copy)
|
2021-01-09 20:13:38 +03:00
|
|
|
{
|
2021-01-12 22:49:16 +03:00
|
|
|
return torch::from_blob(data, shape, dtype<Dtype>()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, copy);
|
2021-01-09 20:13:38 +03:00
|
|
|
}
|
|
|
|
|
2021-01-06 16:20:48 +03:00
|
|
|
template <typename NumType>
|
2021-01-09 20:13:38 +03:00
|
|
|
inline NumType get(const TorchTensorHandle &tensor_handle, int *index)
|
2021-01-06 16:20:48 +03:00
|
|
|
{
|
|
|
|
auto ten = ctorch::cast(tensor_handle);
|
2021-01-09 20:13:38 +03:00
|
|
|
return ten.index(to_index(index, ten.dim())).item<NumType>();
|
2021-01-06 16:20:48 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
template <typename NumType>
|
2021-01-09 20:13:38 +03:00
|
|
|
inline void set(TorchTensorHandle &tensor_handle, int *index, NumType value)
|
2021-01-06 16:20:48 +03:00
|
|
|
{
|
|
|
|
auto ten = ctorch::cast(tensor_handle);
|
2021-01-09 20:13:38 +03:00
|
|
|
ten.index(to_index(index, ten.dim())) = value;
|
2021-01-06 16:20:48 +03:00
|
|
|
}
|
|
|
|
|
2021-01-08 11:27:39 +03:00
|
|
|
template <typename Dtype>
|
|
|
|
inline torch::Tensor randn(std::vector<int64_t> shape, torch::Device device)
|
|
|
|
{
|
|
|
|
return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
|
|
|
}
|
|
|
|
|
2021-01-09 20:13:38 +03:00
|
|
|
template <typename Dtype>
|
|
|
|
inline torch::Tensor rand(std::vector<int64_t> shape, torch::Device device)
|
|
|
|
{
|
|
|
|
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
|
|
|
}
|
|
|
|
|
2021-01-10 19:24:57 +03:00
|
|
|
template <typename Dtype>
|
|
|
|
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
|
|
|
|
{
|
|
|
|
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
|
|
|
}
|
|
|
|
|
2020-12-30 01:42:33 +03:00
|
|
|
} // namespace ctorch
|