File: | build/../torch/csrc/tensor/python_tensor.cpp |
Warning: | line 356, column 36 PyObject ownership leak with reference count of 1 |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | #include <torch/csrc/tensor/python_tensor.h> | |||
2 | ||||
3 | #include <structmember.h> | |||
4 | #include <pybind11/pybind11.h> | |||
5 | ||||
6 | #include <torch/csrc/Dtype.h> | |||
7 | #include <torch/csrc/DynamicTypes.h> | |||
8 | #include <torch/csrc/Exceptions.h> | |||
9 | #include <torch/csrc/Layout.h> | |||
10 | #include <torch/csrc/autograd/variable.h> | |||
11 | #include <torch/csrc/autograd/python_variable.h> | |||
12 | #include <torch/csrc/autograd/generated/VariableType.h> | |||
13 | #include <torch/csrc/autograd/utils/wrap_outputs.h> | |||
14 | #include <torch/csrc/utils/cuda_enabled.h> | |||
15 | #include <torch/csrc/utils/cuda_lazy_init.h> | |||
16 | #include <torch/csrc/utils/python_strings.h> | |||
17 | #include <torch/csrc/utils/tensor_new.h> | |||
18 | #include <torch/csrc/utils/tensor_types.h> | |||
19 | ||||
20 | #include <ATen/ATen.h> | |||
21 | ||||
22 | #include <sstream> | |||
23 | #include <string> | |||
24 | #include <type_traits> | |||
25 | #include <vector> | |||
26 | ||||
27 | namespace torch { namespace tensors { | |||
28 | ||||
29 | using namespace at; | |||
30 | using namespace torch::autograd; | |||
31 | ||||
32 | struct PyTensorType { | |||
33 | PyTypeObject py_type; | |||
34 | THPDtype* dtype; | |||
35 | THPLayout* layout; | |||
36 | bool is_cuda; | |||
37 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) | |||
38 | char name[64]; | |||
39 | int backend; | |||
40 | int scalar_type; | |||
41 | ||||
42 | Backend get_backend() const { | |||
43 | return static_cast<Backend>(backend); | |||
44 | } | |||
45 | ||||
46 | DispatchKey get_dispatch_key() const { | |||
47 | return backendToDispatchKey(static_cast<Backend>(backend)); | |||
48 | } | |||
49 | ||||
50 | ScalarType get_scalar_type() const { | |||
51 | return static_cast<ScalarType>(scalar_type); | |||
52 | } | |||
53 | }; | |||
54 | ||||
55 | static_assert(std::is_standard_layout<PyTensorType>::value, "PyTensorType must be standard layout"); | |||
56 | ||||
57 | // This is always an instance of VariableType | |||
58 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) | |||
59 | static PyTensorType* default_tensor_type; | |||
60 | ||||
61 | static void py_bind_tensor_types(const std::vector<PyTensorType*>& tensor_types); | |||
62 | ||||
63 | static TypeError unavailable_type(const PyTensorType& type) { | |||
64 | return TypeError("type %s not available. Torch not compiled with CUDA enabled.", type.name); | |||
65 | } | |||
66 | ||||
67 | static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { | |||
68 | HANDLE_TH_ERRORStry { torch::PyWarningHandler __enforce_warning_buffer; try { | |||
69 | auto& tensor_type = *((PyTensorType*)type); | |||
70 | if (tensor_type.is_cuda && !torch::utils::cuda_enabled()) { | |||
71 | throw unavailable_type(tensor_type); | |||
72 | } | |||
73 | return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(tensor_type.get_dispatch_key(), tensor_type.get_scalar_type(), args, kwargs)); | |||
74 | END_HANDLE_TH_ERRORS} catch(...) { __enforce_warning_buffer.set_in_exception(); throw ; } } catch (python_error & e) { e.restore(); return nullptr ; } catch (const c10::IndexError& e) { auto msg = torch:: get_cpp_stacktraces_enabled() ? e.what() : e.what_without_backtrace (); PyErr_SetString(PyExc_IndexError, torch::processErrorMsg( msg)); return nullptr; } catch (const c10::ValueError& e) { auto msg = torch::get_cpp_stacktraces_enabled() ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_ValueError , torch::processErrorMsg(msg)); return nullptr; } catch (const c10::TypeError& e) { auto msg = torch::get_cpp_stacktraces_enabled () ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_TypeError , torch::processErrorMsg(msg)); return nullptr; } catch (const c10::NotImplementedError& e) { auto msg = torch::get_cpp_stacktraces_enabled () ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_NotImplementedError , torch::processErrorMsg(msg)); return nullptr; } catch (const c10::Error& e) { auto msg = torch::get_cpp_stacktraces_enabled () ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_RuntimeError , torch::processErrorMsg(msg)); return nullptr; } catch (torch ::PyTorchError & e) { auto msg = torch::processErrorMsg(e .what()); PyErr_SetString(e.python_type(), msg); return nullptr ; } catch (const std::exception& e) { auto msg = torch::processErrorMsg (e.what()); PyErr_SetString(PyExc_RuntimeError, msg); return nullptr ; } | |||
75 | } | |||
76 | ||||
77 | // TODO: Deprecate this instancecheck entirely. It's here to make | |||
78 | // instanceof(t, torch.FloatTensor) work, but we are not going to keep | |||
79 | // adding torch.QuantizedIntTensor classes for every new tensor type | |||
80 | // we add... | |||
81 | static PyObject* Tensor_instancecheck(PyObject* _self, PyObject* arg) { | |||
82 | HANDLE_TH_ERRORStry { torch::PyWarningHandler __enforce_warning_buffer; try { | |||
83 | auto self = (PyTensorType*)_self; | |||
84 | if (THPVariable_Check(arg)) { | |||
85 | const auto& var = THPVariable_Unpack(arg); | |||
86 | // NB: This is a little unfortunate, in that if I do an isinstance check | |||
87 | // against torch.cuda.FloatTensor, this will immediately initialize CUDA. | |||
88 | // I originally thought that it would not be possible for aten_type_ to | |||
89 | // be nullptr if you had a tensor of some type, in which case you can | |||
90 | // skip initializing aten_type(), but TestAutograd.test_type_conversions | |||
91 | // seems to violate this property (for whatever reason.) | |||
92 | // | |||
93 | // TODO: Stop using legacyExtractDispatchKey here (probably need to build | |||
94 | // in instanceof checking to Tensor class itself) | |||
95 | if (legacyExtractDispatchKey(var.key_set()) == self->get_dispatch_key() && | |||
96 | var.scalar_type() == static_cast<ScalarType>(self->scalar_type)) { | |||
97 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
98 | } | |||
99 | } | |||
100 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
101 | END_HANDLE_TH_ERRORS} catch(...) { __enforce_warning_buffer.set_in_exception(); throw ; } } catch (python_error & e) { e.restore(); return nullptr ; } catch (const c10::IndexError& e) { auto msg = torch:: get_cpp_stacktraces_enabled() ? e.what() : e.what_without_backtrace (); PyErr_SetString(PyExc_IndexError, torch::processErrorMsg( msg)); return nullptr; } catch (const c10::ValueError& e) { auto msg = torch::get_cpp_stacktraces_enabled() ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_ValueError , torch::processErrorMsg(msg)); return nullptr; } catch (const c10::TypeError& e) { auto msg = torch::get_cpp_stacktraces_enabled () ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_TypeError , torch::processErrorMsg(msg)); return nullptr; } catch (const c10::NotImplementedError& e) { auto msg = torch::get_cpp_stacktraces_enabled () ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_NotImplementedError , torch::processErrorMsg(msg)); return nullptr; } catch (const c10::Error& e) { auto msg = torch::get_cpp_stacktraces_enabled () ? e.what() : e.what_without_backtrace(); PyErr_SetString(PyExc_RuntimeError , torch::processErrorMsg(msg)); return nullptr; } catch (torch ::PyTorchError & e) { auto msg = torch::processErrorMsg(e .what()); PyErr_SetString(e.python_type(), msg); return nullptr ; } catch (const std::exception& e) { auto msg = torch::processErrorMsg (e.what()); PyErr_SetString(PyExc_RuntimeError, msg); return nullptr ; } | |||
102 | } | |||
103 | ||||
104 | PyObject *Tensor_dtype(PyTensorType* self, void *unused) { | |||
105 | return torch::autograd::utils::wrap(self->dtype); | |||
106 | } | |||
107 | ||||
108 | PyObject *Tensor_layout(PyTensorType* self, void *unused) { | |||
109 | return torch::autograd::utils::wrap(self->layout); | |||
110 | } | |||
111 | ||||
112 | PyObject *Tensor_is_cuda(PyTensorType* self, void *unused) { | |||
113 | if (self->is_cuda) { | |||
114 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
115 | } else { | |||
116 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
117 | } | |||
118 | } | |||
119 | ||||
120 | PyObject *Tensor_is_sparse(PyTensorType *self, void *unused) { | |||
121 | if (self->layout->layout == at::Layout::Strided) { | |||
122 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
123 | } else { | |||
124 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
125 | } | |||
126 | } | |||
127 | ||||
128 | PyObject *Tensor_is_sparse_csr(PyTensorType *self, void *unused) { | |||
129 | if (self->layout->layout == at::Layout::SparseCsr) { | |||
130 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
131 | } else { | |||
132 | Py_RETURN_FALSEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_FalseStruct )))), ((PyObject *) &_Py_FalseStruct); | |||
133 | } | |||
134 | } | |||
135 | ||||
136 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) | |||
137 | static struct PyMethodDef metaclass_methods[] = { | |||
138 | {"__instancecheck__", Tensor_instancecheck, METH_O0x0008, nullptr}, | |||
139 | {nullptr} | |||
140 | }; | |||
141 | ||||
142 | typedef PyObject *(*getter)(PyObject *, void *); | |||
143 | ||||
144 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) | |||
145 | static struct PyGetSetDef metaclass_properties[] = { | |||
146 | {"dtype", (getter)Tensor_dtype, nullptr, nullptr, nullptr}, | |||
147 | {"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr}, | |||
148 | {"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr}, | |||
149 | {"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr}, | |||
150 | {"is_sparse_csr",(getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr}, | |||
151 | {nullptr} | |||
152 | }; | |||
153 | ||||
154 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) | |||
155 | static PyTypeObject metaclass = { | |||
156 | PyVarObject_HEAD_INIT(nullptr, 0){ { 1, nullptr }, 0 }, | |||
157 | "torch.tensortype", /* tp_name */ | |||
158 | sizeof(PyTypeObject) /* tp_basicsize */ | |||
159 | }; | |||
160 | ||||
161 | static void py_initialize_metaclass(PyTypeObject& metaclass) { | |||
162 | metaclass.tp_flags = Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0) | Py_TPFLAGS_BASETYPE(1UL << 10); | |||
163 | metaclass.tp_methods = metaclass_methods; | |||
164 | metaclass.tp_getset = metaclass_properties; | |||
165 | metaclass.tp_base = &PyType_Type; | |||
166 | if (PyType_Ready(&metaclass) < 0) { | |||
167 | throw python_error(); | |||
168 | } | |||
169 | } | |||
170 | ||||
171 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) | |||
172 | static PyTypeObject tensor_type_prototype = { | |||
173 | PyVarObject_HEAD_INIT(&metaclass, 0){ { 1, &metaclass }, 0 }, | |||
174 | nullptr, /* tp_name */ | |||
175 | sizeof(PyTensorType) /* tp_basicsize */ | |||
176 | }; | |||
177 | ||||
178 | static void py_initialize_tensor_type(PyTypeObject& type, const char* name, PyObject* tp_dict) { | |||
179 | // NOTE: we don't use the typical static declaration of PyTypeObject because | |||
180 | // we need to initialize as many types as there are VariableType instances. | |||
181 | // We copy the basic object fields from a prototype definition and initialize | |||
182 | // the remaining fields below. | |||
183 | memcpy(&type, &tensor_type_prototype, sizeof(PyTypeObject)); | |||
184 | // Subclassing from torch.<ScalarType>Tensor isn't supported. | |||
185 | // (Py_TPFLAGS_BASETYPE omitted). Subclassing torch.Tensor still allowed. | |||
186 | type.tp_flags = Py_TPFLAGS_DEFAULT( 0 | (1UL << 18) | 0); | |||
187 | type.tp_name = name; | |||
188 | type.tp_new = Tensor_new; | |||
189 | if (PyType_Ready(&type) < 0) { | |||
190 | throw python_error(); | |||
191 | } | |||
192 | if (PyDict_Merge(type.tp_dict, tp_dict, 0) < 0) { | |||
193 | throw python_error(); | |||
194 | } | |||
195 | } | |||
196 | ||||
197 | static const char* get_module(Backend backend) { | |||
198 | switch (backend) { | |||
199 | case Backend::CPU: return "torch"; | |||
200 | case Backend::CUDA: return "torch.cuda"; | |||
201 | case Backend::SparseCPU: return "torch.sparse"; | |||
202 | case Backend::SparseCUDA: return "torch.cuda.sparse"; | |||
203 | default: AT_ERROR("invalid backend: ", toString(backend))do { ::c10::detail::deprecated_AT_ERROR(); if ((__builtin_expect (static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail ( __func__, "../torch/csrc/tensor/python_tensor.cpp", static_cast <uint32_t>(203), (::c10::detail::torchCheckMsgImpl( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", ::c10:: str("invalid backend: ", toString(backend))))); }; } while (false ); | |||
204 | } | |||
205 | } | |||
206 | ||||
207 | static std::string get_name(Backend backend, ScalarType scalarType) { | |||
208 | std::ostringstream ss; | |||
209 | ss << get_module(backend) << "." << toString(scalarType) << "Tensor"; | |||
210 | return ss.str(); | |||
211 | } | |||
212 | ||||
213 | static THPObjectPtr get_storage_obj(PyTensorType* type) { | |||
214 | auto module_name = get_module(type->get_backend()); | |||
215 | auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name)); | |||
216 | if (!module_obj) throw python_error(); | |||
217 | ||||
218 | auto storage_name = std::string(toString(type->get_scalar_type())) + "Storage"; | |||
219 | THPObjectPtr storage(PyObject_GetAttrString(module_obj.get(), storage_name.c_str())); | |||
220 | if (!storage.get()) { | |||
221 | throw TypeError("couldn't find storage object %s", storage_name.c_str()); | |||
222 | } | |||
223 | return storage; | |||
224 | } | |||
225 | ||||
226 | static void set_type(PyTensorType& type_obj, Backend backend, ScalarType scalarType) { | |||
227 | // This field is lazily initialized from backend and scalar_type | |||
228 | type_obj.backend = static_cast<int>(backend); | |||
229 | type_obj.scalar_type = static_cast<int>(scalarType); | |||
230 | type_obj.layout = torch::getTHPLayout(layout_from_backend(backend)); | |||
231 | type_obj.dtype = torch::getTHPDtype(scalarType); | |||
232 | type_obj.is_cuda = (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA); | |||
233 | } | |||
234 | ||||
235 | static void set_name(PyTensorType& type_obj, const std::string& name) { | |||
236 | size_t n = sizeof(type_obj.name); | |||
237 | strncpy(type_obj.name, name.c_str(), n); | |||
238 | type_obj.name[n - 1] = '\0'; | |||
239 | } | |||
240 | ||||
241 | static THPObjectPtr get_tensor_dict() { | |||
242 | auto torch = THPObjectPtr(PyImport_ImportModule("torch")); | |||
243 | if (!torch) throw python_error(); | |||
244 | ||||
245 | auto tensor_class = THPObjectPtr(PyObject_GetAttrString(torch, "Tensor")); | |||
246 | if (!tensor_class) throw python_error(); | |||
247 | ||||
248 | auto tensor_type = (PyTypeObject*)tensor_class.get(); | |||
249 | TORCH_CHECK(tensor_type->tp_base, "missing base type for Tensor")if ((__builtin_expect(static_cast<bool>(!(tensor_type-> tp_base)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/tensor/python_tensor.cpp" , static_cast<uint32_t>(249), (::c10::detail::torchCheckMsgImpl ( "Expected " "tensor_type->tp_base" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)" , "missing base type for Tensor"))); }; | |||
250 | ||||
251 | auto res = THPObjectPtr(PyDict_New()); | |||
252 | if (!res) throw python_error(); | |||
253 | ||||
254 | if (PyDict_Merge(res.get(), tensor_type->tp_dict, 0) < 0) { | |||
255 | throw python_error(); | |||
256 | } | |||
257 | if (PyDict_Merge(res.get(), tensor_type->tp_base->tp_dict, 0) < 0) { | |||
258 | throw python_error(); | |||
259 | } | |||
260 | ||||
261 | return res; | |||
262 | } | |||
263 | ||||
264 | // A note about the lifetime of the various PyTensorType: normally | |||
265 | // PyTypeObject instances are statically allocated, but we want to create them | |||
266 | // dynamically at init time, because their exact number depends on | |||
267 | // torch::utils::all_declared_types(). The memory for each PyTensorType is | |||
268 | // allocated by initialize_aten_types() and never freed: technically it's a | |||
269 | // leak, but it's not a problem since we want them to be alive for the whole time | |||
270 | // of the process anyway. | |||
271 | // | |||
272 | // An alternative is to use a std::vector<PyTensorType> instead, and let | |||
273 | // std::vector to manage the lifetime of its items. This is problematic | |||
274 | // though, because it means that the memory of PyTensorType is deallocated at | |||
275 | // some point during the exit: if by chance we have another global destructor | |||
276 | // and/or atexit() function which tries to access the PyTensorTypes, we risk | |||
277 | // an use-after-free error. This happens for example if we embed CPython and | |||
278 | // call Py_Finalize inside an atexit() function which was registered before | |||
279 | // importing torch. | |||
280 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) | |||
281 | static std::vector<PyTensorType*> tensor_types; | |||
282 | ||||
283 | void set_default_tensor_type(PyTensorType* type) { | |||
284 | if (!at::isFloatingType(type->get_scalar_type())) { | |||
285 | throw TypeError("only floating-point types are supported as the default type"); | |||
286 | } | |||
287 | if (type->get_backend() == Backend::Undefined) { | |||
288 | throw TypeError("default type cannot be undefined"); | |||
289 | } | |||
290 | if (isSparse(type->get_backend())) { | |||
291 | throw TypeError("only dense types are supported as the default type"); | |||
292 | } | |||
293 | ||||
294 | // get the storage first, so if it doesn't exist we don't change the default tensor type | |||
295 | THPObjectPtr storage = get_storage_obj(type); | |||
296 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) | |||
297 | default_tensor_type = type; | |||
298 | at::set_default_dtype(scalarTypeToTypeMeta(type->get_scalar_type())); | |||
299 | ||||
300 | auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); | |||
301 | if (!torch_module) throw python_error(); | |||
302 | ||||
303 | if (PyObject_SetAttrString(torch_module.get(), "Storage", storage) != 0) { | |||
304 | // technically, we should undo the change of default tensor type. | |||
305 | throw python_error(); | |||
306 | } | |||
307 | } | |||
308 | ||||
309 | static void initialize_aten_types(std::vector<PyTensorType*>& tensor_types) { | |||
310 | // includes CUDA types even when PyTorch is not built with CUDA | |||
311 | auto declared_types = torch::utils::all_declared_types(); | |||
312 | tensor_types.resize(declared_types.size()); | |||
313 | ||||
314 | for (size_t i = 0, end = declared_types.size(); i != end; i++) { | |||
315 | tensor_types[i] = new PyTensorType(); | |||
316 | auto& tensor_type = *tensor_types[i]; | |||
317 | Backend backend = declared_types[i].first; | |||
318 | ScalarType scalar_type = declared_types[i].second; | |||
319 | set_type(tensor_type, backend, scalar_type); | |||
320 | set_name(tensor_type, get_name(backend, scalar_type)); | |||
321 | ||||
322 | // Use torch.float32 as the default tensor type | |||
323 | if (backend == Backend::CPU && scalar_type == at::kFloat) { | |||
324 | set_default_tensor_type(&tensor_type); | |||
325 | } | |||
326 | } | |||
327 | } | |||
328 | ||||
329 | void initialize_python_bindings() { | |||
330 | // Initialize the at::Type* pointers, name, and properties of the PyTensorType | |||
331 | // vector. After this call, the vector must not be resized. | |||
332 | initialize_aten_types(tensor_types); | |||
333 | ||||
334 | // Initialize the Python metaclass for the torch.FloatTensor, etc. types. | |||
335 | // The metaclass handles __instancecheck__ checks and binds the dtype property | |||
336 | // on the type objects. | |||
337 | py_initialize_metaclass(metaclass); | |||
338 | ||||
339 | // Get the tp_dict of the Variable class. We copy function definitions | |||
340 | // onto each Tensor type object so that they can be accessed via e.g. | |||
341 | // `torch.FloatTensor.add`. | |||
342 | auto tensor_dict = get_tensor_dict(); | |||
343 | ||||
344 | // Initialize each Python type object torch.FloatTensor, torch.DoubleTensor, etc. | |||
345 | for (auto& tensor_type : tensor_types) { | |||
346 | py_initialize_tensor_type(tensor_type->py_type, tensor_type->name, tensor_dict.get()); | |||
347 | } | |||
348 | ||||
349 | // Add the type objects to their corresponding modules. e.g. torch.FloatTensor | |||
350 | // is added to the `torch` module as `FloatTensor`. Also add all the type | |||
351 | // objects to the set torch._tensor_classes. | |||
352 | py_bind_tensor_types(tensor_types); | |||
| ||||
353 | } | |||
354 | ||||
355 | static void py_bind_tensor_types(const std::vector<PyTensorType*>& tensor_types) { | |||
356 | auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); | |||
| ||||
357 | if (!torch_module) throw python_error(); | |||
358 | ||||
359 | auto tensor_classes = THPObjectPtr(PyObject_GetAttrString(torch_module.get(), "_tensor_classes")); | |||
360 | if (!tensor_classes) throw python_error(); | |||
361 | ||||
362 | for (auto& tensor_type : tensor_types) { | |||
363 | auto name = std::string(tensor_type->name); | |||
364 | auto idx = name.rfind('.'); | |||
365 | auto type_name = name.substr(idx + 1); | |||
366 | auto module_name = name.substr(0, idx); | |||
367 | ||||
368 | auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name.c_str())); | |||
369 | if (!module_obj) throw python_error(); | |||
370 | ||||
371 | PyObject* type_obj = (PyObject*)tensor_type; | |||
372 | Py_INCREF(type_obj)_Py_INCREF(((PyObject*)(type_obj))); | |||
373 | if (PyModule_AddObject(module_obj.get(), type_name.c_str(), type_obj) < 0) { | |||
374 | throw python_error(); | |||
375 | } | |||
376 | if (PySet_Add(tensor_classes.get(), type_obj) < 0) { | |||
377 | throw python_error(); | |||
378 | } | |||
379 | } | |||
380 | } | |||
381 | ||||
382 | static bool PyTensorType_Check(PyObject* obj) { | |||
383 | auto it = std::find_if(tensor_types.begin(), tensor_types.end(), | |||
384 | [obj](PyTensorType *x) { | |||
385 | return (PyObject*)x == obj; | |||
386 | }); | |||
387 | return it != tensor_types.end(); | |||
388 | } | |||
389 | ||||
390 | void py_set_default_tensor_type(PyObject* obj) { | |||
391 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | |||
392 | PyTensorType *type; | |||
393 | if (PyTensorType_Check(obj)) { | |||
394 | type = (PyTensorType*)obj; | |||
395 | } else { | |||
396 | throw TypeError("invalid type object"); | |||
397 | } | |||
398 | if (type->is_cuda && !torch::utils::cuda_enabled()) { | |||
399 | throw unavailable_type(*type); | |||
400 | } | |||
401 | set_default_tensor_type(type); | |||
402 | } | |||
403 | ||||
404 | void py_set_default_dtype(PyObject* obj) { | |||
405 | if (THPDtype_Check(obj)) { | |||
406 | auto scalar_type = ((THPDtype*)obj)->scalar_type; | |||
407 | auto backend = default_tensor_type->get_backend(); | |||
408 | auto it = std::find_if(tensor_types.begin(), tensor_types.end(), | |||
409 | [backend, scalar_type](PyTensorType *x) { | |||
410 | return x->get_backend() == backend && x->get_scalar_type() == scalar_type; | |||
411 | }); | |||
412 | set_default_tensor_type(*it); | |||
413 | } else { | |||
414 | throw TypeError("invalid dtype object"); | |||
415 | } | |||
416 | } | |||
417 | ||||
418 | c10::DispatchKey get_default_dispatch_key() { | |||
419 | AT_ASSERT(default_tensor_type)do { ::c10::detail::deprecated_AT_ASSERT(); if ((__builtin_expect (static_cast<bool>(!(default_tensor_type)), 0))) { ::c10 ::detail::torchInternalAssertFail( __func__, "../torch/csrc/tensor/python_tensor.cpp" , static_cast<uint32_t>(419), "default_tensor_type" "INTERNAL ASSERT FAILED at " "\"../torch/csrc/tensor/python_tensor.cpp\"" ":" "419" ", please report a bug to PyTorch. " , c10::str()); }; } while (false); | |||
420 | return default_tensor_type->get_dispatch_key(); | |||
421 | } | |||
422 | ||||
423 | ScalarType get_default_scalar_type() { | |||
424 | return typeMetaToScalarType(get_default_dtype()); | |||
425 | } | |||
426 | ||||
427 | }} // namespace torch::tensors |
#ifndef PyImport_ImportModule
// Static-analyzer model of PyImport_ImportModule, compiled only when
// PyImport_ImportModule is not a macro. The result comes from
// clang_analyzer_PyObject_New_Reference(); presumably this tells the analyzer
// that the caller receives a new (owned) reference.
struct _object;
typedef struct _object PyObject;
PyObject* clang_analyzer_PyObject_New_Reference();
PyObject* PyImport_ImportModule(const char *name) {
  (void)name;  // the module name is irrelevant to the model
  return clang_analyzer_PyObject_New_Reference();
}
#else
#warning "API PyImport_ImportModule is defined as a macro."
#endif