#include #include "ctorch.h" namespace ctorch { template inline c10::ScalarType dtype() { return torch::kFloat64; } template <> inline c10::ScalarType dtype() { return torch::kFloat32; } template <> inline c10::ScalarType dtype() { return torch::kInt64; } template <> inline c10::ScalarType dtype() { return torch::kInt32; } inline torch::Tensor &cast(TorchTensorHandle tensor_handle) { return *static_cast(tensor_handle); } template inline torch::Tensor copy_from_blob(Dtype *data, int *shape, int dim) { auto shape_vec = std::vector(dim); shape_vec.assign(shape, shape + dim); return torch::from_blob(data, shape_vec, dtype()).clone(); } template inline int *to_dynamic_ints(IntArray arr) { size_t n = arr.size(); int *res = (int *)malloc(sizeof(int) * n); for (size_t i = 0; i < n; i++) { res[i] = arr[i]; } return res; } } // namespace ctorch