#include <torch/torch.h>

#include "ctorch.h"

namespace ctorch
{
    // Map a C++ scalar type to the corresponding torch dtype.
    // The primary template defaults to double; specializations
    // cover float, long and int.
    template <typename Dtype>
    inline c10::ScalarType dtype()
    {
        return torch::kFloat64;
    }

    template <>
    inline c10::ScalarType dtype<float>()
    {
        return torch::kFloat32;
    }

    template <>
    inline c10::ScalarType dtype<long>()
    {
        return torch::kInt64;
    }

    template <>
    inline c10::ScalarType dtype<int>()
    {
        return torch::kInt32;
    }

    // Recover the torch::Tensor behind an opaque handle.
    inline torch::Tensor &cast(const TorchTensorHandle &tensor_handle)
    {
        return *static_cast<torch::Tensor *>(tensor_handle);
    }

    // Encode a device as an int: 0 for CPU, 1 + index for CUDA devices.
    inline int device_to_int(const torch::Tensor &tensor)
    {
        return (tensor.device().type() == torch::kCPU) ? 0 : 1 + tensor.device().index();
    }

    // Inverse of device_to_int.
    inline torch::Device int_to_device(int device_int)
    {
        return (device_int == 0) ? torch::kCPU : torch::Device(torch::kCUDA, device_int - 1);
    }

    // Copy a C int array into the int64_t vector torch expects for shapes.
    inline std::vector<int64_t> to_vec_int(int *arr, int arr_size)
    {
        auto vec = std::vector<int64_t>(arr_size);
        vec.assign(arr, arr + arr_size);
        return vec;
    }

    // Build a multi-dimensional tensor index from a C int array.
    inline std::vector<at::indexing::TensorIndex> to_index(int *arr, int arr_size)
    {
        std::vector<at::indexing::TensorIndex> index;
        for (int i = 0; i < arr_size; i++)
        {
            index.emplace_back(arr[i]);
        }
        return index;
    }

    // Copy external data into a fresh strided tensor on the given device.
    template <typename Dtype>
    inline torch::Tensor copy_from_blob(Dtype *data, std::vector<int64_t> shape, torch::Device device)
    {
        return torch::from_blob(data, shape, dtype<Dtype>())
            .to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true);
    }

    // Read a single element at the given multi-dimensional index.
    template <typename NumType>
    inline NumType get(const TorchTensorHandle &tensor_handle, int *index)
    {
        auto ten = ctorch::cast(tensor_handle);
        return ten.index(to_index(index, ten.dim())).item<NumType>();
    }

    // Write a single element at the given multi-dimensional index.
    template <typename NumType>
    inline void set(TorchTensorHandle &tensor_handle, int *index, NumType value)
    {
        auto ten = ctorch::cast(tensor_handle);
        ten.index(to_index(index, ten.dim())) = value;
    }

    // Normally distributed random tensor of the given shape, dtype and device.
    template <typename Dtype>
    inline torch::Tensor randn(std::vector<int64_t> shape, torch::Device device)
    {
        return torch::randn(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
    }

    // Uniformly distributed random tensor of the given shape, dtype and device.
    template <typename Dtype>
    inline torch::Tensor rand(std::vector<int64_t> shape, torch::Device device)
    {
        return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
    }

} // namespace ctorch
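
/*
   Minimal usage sketch (illustrative, not part of the header). It assumes
   TorchTensorHandle is an opaque void * alias declared in "ctorch.h" and
   that the program links against libtorch; "utils.hh" is a hypothetical
   name for this header. Everything else uses only the helpers above.

   #include "utils.hh"

   int main()
   {
       // Copy a host buffer into a 2x2 float tensor on the CPU (device 0).
       float data[] = {1.f, 2.f, 3.f, 4.f};
       torch::Tensor ten =
           ctorch::copy_from_blob<float>(data, {2, 2}, ctorch::int_to_device(0));

       // Round-trip an element through the opaque handle used by the C API.
       TorchTensorHandle handle = &ten;
       int ij[] = {1, 0};
       ctorch::set<float>(handle, ij, 7.f);           // ten[1][0] = 7
       float value = ctorch::get<float>(handle, ij);  // value == 7
       return value == 7.f ? 0 : 1;
   }
*/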