#include #include "ctorch.h" namespace ctorch { template inline c10::ScalarType dtype() { return torch::kFloat64; } template <> inline c10::ScalarType dtype() { return torch::kFloat32; } template <> inline c10::ScalarType dtype() { return torch::kInt64; } template <> inline c10::ScalarType dtype() { return torch::kInt32; } inline torch::Tensor &cast(const TorchTensorHandle &tensor_handle) { return *static_cast(tensor_handle); } inline std::vector to_vec_int(int *arr, int arr_size) { auto vec = std::vector(arr_size); vec.assign(arr, arr + arr_size); return vec; } template inline torch::Tensor copy_from_blob(Dtype *data, std::vector shape, torch::Device device) { return torch::from_blob(data, shape, dtype()).to(torch::TensorOptions().layout(torch::kStrided).device(device), false, true); } inline int *to_dynamic_ints(const c10::IntArrayRef &arr) { size_t n = arr.size(); int *res = (int *)malloc(sizeof(int) * n); for (size_t i = 0; i < n; i++) { res[i] = arr[i]; } return res; } inline std::vector offset_to_index(int offset, const c10::IntArrayRef &strides) { std::vector index; for (const auto &stride : strides) { index.emplace_back(offset / stride); offset %= stride; } return index; } template inline NumType get_at_offset(const TorchTensorHandle &tensor_handle, int offset) { auto ten = ctorch::cast(tensor_handle); return ten.index(ctorch::offset_to_index(offset, ten.strides())).item(); } template inline void set_at_offset(TorchTensorHandle &tensor_handle, int offset, NumType value) { auto ten = ctorch::cast(tensor_handle); ten.index(offset_to_index(offset, ten.strides())) = value; } template inline torch::Tensor randn(std::vector shape, torch::Device device) { return torch::randn(shape, torch::TensorOptions().dtype(dtype()).layout(torch::kStrided).device(device)); } } // namespace ctorch