| File: | build/../torch/csrc/autograd/python_function.cpp |
| Warning: | line 84, column 18: PyObject ownership leak with reference count of 1 |
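The warning above points at source line 84, where the result of PyObject_CallObject is handed to THPObjectPtr, PyTorch's RAII holder, which releases the reference when it goes out of scope at the end of PyNode::apply; the checker most likely does not model that wrapper and so reports the new reference as leaked. As a point of reference, here is a minimal sketch of the CPython ownership rule being enforced; the function and variable names below are illustrative and do not come from this file.

    #include <Python.h>

    // Illustrative only: PyObject_CallObject returns a *new* reference (or NULL on
    // error), and the caller owns it. Every path must release it exactly once.
    static int call_and_discard(PyObject* callable, PyObject* args_tuple) {
      PyObject* result = PyObject_CallObject(callable, args_tuple);  // new reference or NULL
      if (!result) {
        return -1;          // error already set; nothing to release
      }
      // ... inspect 'result' here ...
      Py_DECREF(result);    // forgetting this leaves the object alive with refcount 1
      return 0;
    }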
| 1 | #include <torch/csrc/autograd/python_function.h> | |||
| 2 | ||||
| 3 | #include <torch/csrc/python_headers.h> | |||
| 4 | #include <structmember.h> | |||
| 5 | #include <ATen/ATen.h> | |||
| 6 | #include <ATen/SequenceNumber.h> | |||
| 7 | #include <c10/util/irange.h> | |||
| 8 | #include <pybind11/pybind11.h> | |||
| 9 | ||||
| 10 | #include <torch/csrc/THP.h> | |||
| 11 | #include <torch/csrc/autograd/grad_mode.h> | |||
| 12 | #include <torch/csrc/autograd/functions/accumulate_grad.h> | |||
| 13 | #include <torch/csrc/autograd/functions/basic_ops.h> | |||
| 14 | #include <torch/csrc/autograd/functions/utils.h> | |||
| 15 | #include <torch/csrc/autograd/python_cpp_function.h> | |||
| 16 | #include <torch/csrc/autograd/python_hook.h> | |||
| 17 | #include <torch/csrc/autograd/saved_variable.h> | |||
| 18 | #include <torch/csrc/autograd/python_anomaly_mode.h> | |||
| 19 | #include <torch/csrc/jit/frontend/tracer.h> | |||
| 20 | #include <torch/csrc/jit/ir/ir.h> | |||
| 21 | #include <torch/csrc/jit/python/python_tracer.h> | |||
| 22 | #include <torch/csrc/jit/python/pybind_utils.h> | |||
| 23 | #include <torch/csrc/utils/python_strings.h> | |||
| 24 | #include <torch/csrc/DynamicTypes.h> | |||
| 25 | #include <torch/csrc/Exceptions.h> | |||
| 26 | ||||
| 27 | #include <exception> | |||
| 28 | #include <functional> | |||
| 29 | #include <memory> | |||
| 30 | #include <stdexcept> | |||
| 31 | #include <string> | |||
| 32 | #include <tuple> | |||
| 33 | #include <unordered_map> | |||
| 34 | #include <unordered_set> | |||
| 35 | #include <utility> | |||
| 36 | #include <vector> | |||
| 37 | ||||
| 38 | using namespace torch; | |||
| 39 | using namespace torch::autograd; | |||
| 40 | using namespace torch::jit; | |||
| 41 | using at::Tensor; | |||
| 42 | ||||
| 43 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) | |||
| 44 | PyObject *THPFunctionClass = nullptr; | |||
| 45 | ||||
| 46 | #define THPFunction_assert(condition, ...) \ | |||
| 47 | if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); } | |||
| 48 | ||||
| 49 | namespace torch { namespace autograd { | |||
| 50 | ||||
| 51 | void PyNode::throw_python_error() { | |||
| 52 | python_error err; | |||
| 53 | err.persist(); | |||
| 54 | throw err; | |||
| 55 | } | |||
| 56 | ||||
| 57 | // NOTE: this function is written in a way that assumes it's only called for backward; | |||
| 58 | // it's used by engine.cpp. This is responsible for forwarding a call from | |||
| 59 | // C++'s Node::apply to a Python method "apply". | |||
| 60 | auto PyNode::apply(variable_list&& inputs) -> variable_list { | |||
| 61 | pybind11::gil_scoped_acquire gil; | |||
| 62 | at::OptionalDeviceGuard _device_guard; | |||
| 63 | THPFunction* py_fn = (THPFunction*)obj; | |||
| 64 | ||||
| 65 | // Massage a C++ variable_list into a Python arguments tuple | |||
| 66 | auto num_inputs = inputs.size(); | |||
| 67 | THPObjectPtr pyInputs(PyTuple_New(num_inputs)); | |||
| 68 | if (!pyInputs) throw_python_error(); | |||
| 69 | auto& output_info = py_fn->output_info; | |||
| 70 | for (const auto i : c10::irange(num_inputs)) { | |||
| 71 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | |||
| 72 | PyObject* input; | |||
| 73 | if (inputs[i].defined() || !py_fn->materialize_grads) { | |||
| 74 | input = THPVariable_Wrap(inputs[i]); | |||
| 75 | } else { | |||
| 76 | input = THPVariable_Wrap(output_info[i].zeros(_device_guard)); | |||
| 77 | } | |||
| 78 | if (!input) throw_python_error(); | |||
| 79 | PyTuple_SET_ITEM(pyInputs.get(), i, input); | |||
| 80 | } | |||
| 81 | ||||
| 82 | THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply")); | |||
| 83 | if (!apply_fn) throw_python_error(); | |||
| 84 | THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get())); | |||
| 85 | if (!r) throw_python_error(); | |||
| 86 | ensure_tuple(r); | |||
| 87 | ||||
| 88 | auto& is_variable_input = py_fn->is_variable_input; | |||
| 89 | int num_outputs = PyTuple_GET_SIZE(r.get()); | |||
| 90 | int num_forward_inputs = is_variable_input.size(); | |||
| 91 | // Returning too many results is ok, but only as long as they're all None. | |||
| 92 | // Truncate the result tuple in that case. | |||
| 93 | if (num_outputs > num_forward_inputs) { | |||
| 94 | bool all_none = true; | |||
| 95 | for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { | |||
| 96 | all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None; | |||
| 97 | } | |||
| 98 | if (all_none) { | |||
| 99 | num_outputs = num_forward_inputs; | |||
| 100 | r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs); | |||
| 101 | if (!r) throw_python_error(); | |||
| 102 | } | |||
| 103 | } | |||
| 104 | ||||
| 105 | // Now the number of gradients should match | |||
| 106 | if (num_outputs != num_forward_inputs) { | |||
| 107 | std::string msg("function "); | |||
| 108 | msg += name() + " returned an incorrect number of gradients (expected "; | |||
| 109 | msg += std::to_string(num_forward_inputs) + ", got " ; | |||
| 110 | msg += std::to_string(num_outputs) + ")"; | |||
| 111 | throw std::runtime_error(msg); | |||
| 112 | } | |||
| 113 | ||||
| 114 | // Massage the Python results tuple back into a C++ variable_list | |||
| 115 | variable_list results; | |||
| 116 | results.reserve(num_outputs); | |||
| 117 | for (int i = 0; i != num_outputs; ++i) { | |||
| 118 | PyObject* output = PyTuple_GET_ITEM(r.get(), i); | |||
| 119 | bool was_variable = is_variable_input[i]; | |||
| 120 | if (!was_variable) { | |||
| 121 | if (output != Py_None) { | |||
| 122 | std::string msg("function "); | |||
| 123 | msg += name() + " returned a gradient different than None at position "; | |||
| 124 | msg += std::to_string(i + 1) + ", but the corresponding forward input was not a Variable"; | |||
| 125 | throw std::runtime_error(msg); | |||
| 126 | } | |||
| 127 | continue; | |||
| 128 | } | |||
| 129 | if (output == Py_None) { | |||
| 130 | results.emplace_back(); | |||
| 131 | } else { | |||
| 132 | if (!THPVariable_Check(output)) { | |||
| 133 | std::string msg("expected Variable or None (got "); | |||
| 134 | msg += THPUtils_typename(output); | |||
| 135 | msg += ")"; | |||
| 136 | throw std::runtime_error(msg); | |||
| 137 | } | |||
| 138 | results.emplace_back(THPVariable_Unpack(output)); | |||
| 139 | } | |||
| 140 | } | |||
| 141 | ||||
| 142 | return results; | |||
| 143 | } | |||
| 144 | ||||
| 145 | auto PyNode::is_traceable() -> bool { | |||
| 146 | pybind11::gil_scoped_acquire gil; | |||
| 147 | THPObjectPtr forward_class {PyObject_GetAttrString(obj, "_forward_cls")}; | |||
| 148 | if (!forward_class) throw_python_error(); | |||
| 149 | THPObjectPtr traceable_py_bool {PyObject_GetAttrString(forward_class, "is_traceable")}; | |||
| 150 | if (!traceable_py_bool) throw_python_error(); | |||
| 151 | return traceable_py_bool == Py_True; | |||
| 152 | } | |||
| 153 | ||||
| 154 | auto PyNode::release_variables() -> void { | |||
| 155 | pybind11::gil_scoped_acquire gil; | |||
| 156 | auto f = (THPFunction*) obj; | |||
| 157 | f->saved_variables.clear(); | |||
| 158 | f->has_freed_buffers = 1; | |||
| 159 | } | |||
| 160 | ||||
| 161 | auto PyNode::name() const -> std::string { | |||
| 162 | pybind11::gil_scoped_acquire gil; | |||
| 163 | auto f = (THPFunction*) obj; | |||
| 164 | auto name = std::string(Py_TYPE(f)->tp_name); | |||
| 165 | return name; | |||
| 166 | } | |||
| 167 | ||||
| 168 | }} // namespace torch::autograd | |||
| 169 | ||||
| 170 | // Traverse and clear are required for supporting Python's GC cycle handling. | |||
| 171 | static int THPFunction_traverse(THPFunction *self, visitproc visit, void *arg) | |||
| 172 | { | |||
| 173 | // cdata could be null if the PyNode has already gone out of scope | |||
| 174 | // by the time we're GC'ing this THPFunction (e.g., the user saved grad_fn only). | |||
| 175 | // | |||
| 176 | // TODO: I'm not really sure if we're actually obligated to traverse PyObject | |||
| 177 | // that is stored in PyNode, since we don't really own that C++ object. | |||
| 178 | if (auto cdata = self->cdata.lock()) { | |||
| 179 | for (const auto& hook : cdata->pre_hooks()) { | |||
| 180 | if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) { | |||
| 181 | Py_VISIT(pyhook->dict); | |||
| 182 | } | |||
| 183 | } | |||
| 184 | for (const auto& hook : cdata->post_hooks()) { | |||
| 185 | if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) { | |||
| 186 | Py_VISIT(pyhook->dict); | |||
| 187 | } | |||
| 188 | } | |||
| 189 | } | |||
| 190 | Py_VISIT(self->to_save); | |||
| 191 | Py_VISIT(self->non_differentiable); | |||
| 192 | Py_VISIT(self->dirty_tensors); | |||
| 193 | return 0; | |||
| 194 | } | |||
| 195 | ||||
| 196 | static int THPFunction_clear(THPFunction *self) | |||
| 197 | { | |||
| 198 | // Note that the cdata might not be expired yet in the case where this | |||
| 199 | // object is part of a cycle and the GC happens to tp_clear this PyObject | |||
| 200 | // before the other ones that trigger the de-allocation of the cdata | |||
| 201 | ||||
| 202 | Py_CLEAR(self->needs_input_grad); | |||
| 203 | ||||
| 204 | Py_CLEAR(self->to_save); | |||
| 205 | Py_CLEAR(self->non_differentiable); | |||
| 206 | Py_CLEAR(self->dirty_tensors); | |||
| 207 | ||||
| 208 | self->output_info.clear(); | |||
| 209 | self->input_info.clear(); | |||
| 210 | self->saved_variables.clear(); | |||
| 211 | self->is_variable_input.clear(); | |||
| 212 | ||||
| 213 | return 0; | |||
| 214 | } | |||
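The traverse/clear pair above implements CPython's cycle-GC protocol: tp_traverse reports every owned PyObject to the collector via Py_VISIT, and tp_clear drops those references via Py_CLEAR so a reference cycle can be reclaimed. A minimal, self-contained sketch of how such slots are typically written for a GC-enabled type follows; the Demo names are illustrative and unrelated to THPFunction's actual type registration.

    #include <Python.h>

    // Illustrative GC-enabled object with one owned reference.
    typedef struct {
      PyObject_HEAD
      PyObject* payload;  // owned reference that may participate in a cycle
    } DemoObject;

    // tp_traverse: report each owned PyObject so the collector can find cycles.
    static int Demo_traverse(DemoObject* self, visitproc visit, void* arg) {
      Py_VISIT(self->payload);
      return 0;
    }

    // tp_clear: break cycles by dropping owned references.
    static int Demo_clear(DemoObject* self) {
      Py_CLEAR(self->payload);
      return 0;
    }

    // These functions would be installed in the type's tp_traverse and tp_clear
    // slots, with Py_TPFLAGS_HAVE_GC set in tp_flags, mirroring what the
    // THPFunction type object does elsewhere in this codebase.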
| 215 | ||||
| 216 | static void THPFunction_dealloc(THPFunction* self) | |||
| 217 | { | |||
| 218 | // Why is this guaranteed to be true? Suppose that self->cdata is non-null | |||
| 219 | // (otherwise the condition is trivially true). Then there is a PyNode | |||
| 220 | // which contains an owning reference to this object. But we are only | |||
| 221 | // allowed to clear if all owning references are gone! Contradiction. | |||
| 222 | // | |||
| 223 | // However, note that THPFunction_clear is typically called in the shared_ptr | |||
| 224 | // destructor of PyNode; in that case, per | |||
| 225 | // https://cplusplus.github.io/LWG/lwg-active.html#2751 it's not currently | |||
| 226 | // specified in the standard that this is guaranteed. If you see this | |||
| 227 | // assert triggering in the wild, feel free to comment it out. They're | |||
| 228 | // likely to standardize that you ARE guaranteed to see the weak pointers | |||
| 229 | // as expired in the destructor in the future, so we'll keep this for now. | |||
| 230 | TORCH_INTERNAL_ASSERT(self->cdata.expired()); | |||
| 231 | ||||
| 232 | PyObject_GC_UnTrack(self); | |||
| 233 | THPFunction_clear(self); | |||
| 234 | self->cdata.~weak_ptr<PyNode>(); | |||
| 235 | self->output_info.~vector(); | |||
| 236 | self->input_info.~vector(); | |||
| 237 | self->saved_variables.~vector(); | |||
| 238 | self->is_variable_input.~vector(); | |||
| 239 | Py_TYPE(self)->tp_free((PyObject*)self); | |||
| 240 | } | |||
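The assertion above hinges on the LWG 2751 question the comment describes: whether a weak_ptr already reads as expired() while the managed object's destructor runs during the release of the last shared_ptr. Below is a minimal standalone sketch of that situation, assuming a mainstream implementation (libstdc++ and libc++ decrement the strong count before invoking the destructor, so the observer is expired by then); it illustrates the question rather than a guarantee from the standard.

    #include <cassert>
    #include <memory>

    struct Node {
      std::weak_ptr<Node>* observer = nullptr;  // a weak_ptr watching this object
      ~Node() {
        // Runs while the last shared_ptr is being released. On common
        // implementations the strong count is already zero here, so the
        // weak_ptr reads as expired.
        if (observer) assert(observer->expired());
      }
    };

    int main() {
      std::weak_ptr<Node> watcher;
      {
        auto node = std::make_shared<Node>();
        watcher = node;
        node->observer = &watcher;
      }  // last shared_ptr dies; ~Node() checks watcher.expired()
      assert(watcher.expired());
      return 0;
    }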
| 241 | ||||
| 242 | PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) | |||
| 243 | { | |||
| 244 | PyObject* obj = type->tp_alloc(type, 0); | |||
| 245 | if (!obj) return nullptr; | |||
| 246 | // Python zero-initializes the object memory, so there's no need to initialize | |||
| 247 | // most fields | |||
| 248 | THPFunction* self = (THPFunction*)obj; | |||
| 249 | // Setup the PyNode later; we can't keep it live here | |||
| 250 | new (&self->cdata) std::weak_ptr<PyNode>(); | |||
| 251 | new (&self->output_info) std::vector<VariableInfo>(); | |||
| 252 | new (&self->input_info) std::vector<VariableInfo>(); | |||
| 253 | new (&self->saved_variables) std::vector<SavedVariable>(); | |||
| 254 | new (&self->is_variable_input) std::vector<bool>(); | |||
| 255 | self->materialize_grads = true; | |||
| 256 | return obj; | |||
| 257 | } | |||
| 258 | ||||
| 259 | //////////////////////////////////////////////////////////////////////////////// | |||
| 260 | // Forward | |||
| 261 | //////////////////////////////////////////////////////////////////////////////// | |||
| 262 | ||||
| 263 | // Bump the counters of all recorded dirty input tensors, adding each of them | |||
| 264 | // into dirty_inputs. Also does some sanity checking. | |||
| 265 | static std::unordered_set<at::TensorImpl*> _mark_dirty(THPFunction *self) | |||
| 266 | { | |||
| 267 | // Increase versions of modified tensors | |||
| 268 | std::unordered_set<at::TensorImpl*> dirty_inputs; | |||
| 269 | if (!self->dirty_tensors) return dirty_inputs; | |||
| 270 | ||||
| 271 | THPFunction_assert(PyTuple_Check(self->dirty_tensors), "autograd " | |||
| 272 | "internal error: dirty_tensors attribute is expected to be a tuple " | |||
| 273 | "but is %s", THPUtils_typename(self->dirty_tensors)); | |||
| 274 | Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors); | |||
| 275 | dirty_inputs.reserve(num_dirty); | |||
| 276 | for(const auto i : c10::irange(num_dirty)) { | |||
| 277 | PyObject *obj = PyTuple_GET_ITEM(self->dirty_tensors, i); | |||
| 278 | THPFunction_assert(THPVariable_Check(obj), "mark_dirty can " | |||
| 279 | "only accept variables, but argument %d is of type %s", i, | |||
| 280 | THPUtils_typename(obj)); | |||
| 281 | ||||
| 282 | const auto& tensor = THPVariable_Unpack(obj); | |||
| 283 | dirty_inputs.insert(tensor.unsafeGetTensorImpl()); | |||
| 284 | torch::autograd::impl::bump_version(tensor); | |||
| 285 | } | |||
| 286 | // We're not going to ever need this so let's remove references now | |||
| 287 | Py_CLEAR(self->dirty_tensors); | |||
| 288 | return dirty_inputs; | |||
| 289 | } | |||
| 290 | ||||
| 291 | static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(THPFunction *self); | |||
| 292 | ||||
| 293 | // Given a Python tuple of raw output tensors (raw_output), set each of | |||
| 294 | // the corresponding entries in a different Python tuple (outputs) with | |||
| 295 | // these tensors wrapped with variables. We save the gradient function (self) | |||
| 296 | // to the variable if the output requires grad. | |||
| 297 | // | |||
| 298 | // There is a considerable amount of complexity to handle if the operation | |||
| 299 | // that produced these output tensors is inplace. A mapping of *input* | |||
| 300 | // tensors to variables (t2var) is used to test if this occurred, and | |||
| 301 | // the set of dirty tensors (dirty_inputs) is used to figure out what to | |||
| 302 | // do in this case. After this method is run, t2var is extended with | |||
| 303 | // mappings for output tensors as well. | |||
| 304 | static void _wrap_outputs(const std::shared_ptr<PyNode>& cdata, THPFunction *self, | |||
| 305 | const variable_list &input_vars, PyObject *raw_output, PyObject *outputs, bool is_executable) | |||
| 306 | { | |||
| 307 | auto cdata_if_executable = is_executable ? cdata : nullptr; | |||
| 308 | Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output); | |||
| 309 | if (is_executable) { | |||
| 310 | self->output_info.clear(); | |||
| 311 | self->output_info.reserve(num_outputs); | |||
| 312 | } | |||
| 313 | ||||
| 314 | auto non_differentiable = _parse_non_differentiable(self); | |||
| 315 | auto dirty_inputs = _mark_dirty(self); | |||
| 316 | ||||
| 317 | std::vector<c10::optional<Variable>> raw_output_vars; | |||
| 318 | raw_output_vars.reserve(num_outputs); | |||
| 319 | for (const auto i : c10::irange(num_outputs)) { | |||
| 320 | PyObject* obj = PyTuple_GET_ITEM(raw_output, i); | |||
| 321 | // Only process tensors as outputs for autograd purposes. | |||
| 322 | if (THPVariable_Check(obj)) { | |||
| 323 | raw_output_vars.emplace_back(THPVariable_Unpack(obj)); | |||
| 324 | } else { | |||
| 325 | raw_output_vars.emplace_back(); | |||
| 326 | } | |||
| 327 | } | |||
| 328 | ||||
| 329 | // Wrap only the tensor outputs. | |||
| 330 | auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, raw_output_vars, cdata_if_executable); | |||
| 331 | ||||
| 332 | for(const auto i : c10::irange(num_outputs)) { | |||
| 333 | PyObject* obj = PyTuple_GetItem(raw_output, i); | |||
| 334 | // Keep the non-tensor outputs as is. | |||
| 335 | if (!THPVariable_Check(obj)) { | |||
| 336 | if (is_executable) { | |||
| 337 | self->output_info.emplace_back(); | |||
| 338 | } | |||
| 339 | Py_INCREF(obj); | |||
| 340 | PyTuple_SetItem(outputs, i, obj); | |||
| 341 | } else { | |||
| 342 | if (is_executable) { | |||
| 343 | self->output_info.emplace_back(*wrapped_outputs[i]); | |||
| 344 | } | |||
| 345 | PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i])); | |||
| 346 | } | |||
| 347 | } | |||
| 348 | } | |||
| 349 | ||||
| 350 | // Save any variables requested by to_save | |||
| 351 | static void _save_variables(const std::shared_ptr<PyNode>& cdata_ptr, THPFunction* self) | |||
| 352 | { | |||
| 353 | if (!self->to_save) return; | |||
| 354 | ||||
| 355 | THPFunction_assert(PyTuple_Check(self->to_save), "autograd internal " | |||
| 356 | "error: to_save attribute is expected to be a tuple but is %s", | |||
| 357 | THPUtils_typename(self->to_save)); | |||
| 358 | Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save); | |||
| 359 | self->saved_variables.clear(); | |||
| 360 | self->saved_variables.reserve(num_saved); | |||
| 361 | for(const auto i : c10::irange(num_saved)) { | |||
| 362 | PyObject *obj = PyTuple_GET_ITEM(self->to_save, i); | |||
| 363 | if (obj == Py_None) { | |||
| 364 | self->saved_variables.emplace_back(); | |||
| 365 | continue; | |||
| 366 | } else if (THPVariable_Check(obj)) { | |||
| 367 | const auto& tensor = THPVariable_Unpack(obj); | |||
| 368 | bool is_output = tensor.grad_fn().get() == cdata_ptr.get(); | |||
| 369 | self->saved_variables.emplace_back(tensor, is_output); | |||
| 370 | } else { | |||
| 371 | throw torch::TypeError( | |||
| 372 | "save_for_backward can only save variables, but argument %ld is of " | |||
| 373 | "type %s", i, Py_TYPE(obj)(((PyObject*)(obj))->ob_type)->tp_name); | |||
| 374 | } | |||
| 375 | } | |||
| 376 | // Free .to_save | |||
| 377 | Py_CLEAR(self->to_save); | |||
| 378 | } | |||
| 379 | ||||
| 380 | // Mark requires_grad = 0 on non-differentiable variables (as per non_differentiable) | |||
| 381 | static std::unordered_set<at::TensorImpl*> | |||
| 382 | _parse_non_differentiable(THPFunction *self) | |||
| 383 | { | |||
| 384 | std::unordered_set<at::TensorImpl*> set; | |||
| 385 | if (!self->non_differentiable) return set; | |||
| 386 | ||||
| 387 | THPFunction_assert(PyTuple_Check(self->non_differentiable), "autograd " | |||
| 388 | "internal error: non_differentiable attribute is expected to be a " | |||
| 389 | "tuple but is %s", THPUtils_typename(self->non_differentiable)); | |||
| 390 | Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable); | |||
| 391 | set.reserve(num_nondiff); | |||
| 392 | for(const auto i : c10::irange(num_nondiff)) { | |||
| 393 | PyObject *t = PyTuple_GET_ITEM(self->non_differentiable, i); | |||
| 394 | THPFunction_assert(THPVariable_Check(t), "mark_non_differentiable " | |||
| 395 | "only accepts variable arguments, but got %s", THPUtils_typename(t)); | |||
| 396 | set.insert(THPVariable_Unpack(t).unsafeGetTensorImpl()); | |||
| 397 | } | |||
| 398 | Py_CLEAR(self->non_differentiable); | |||
| 399 | return set; | |||
| 400 | } | |||
| 401 | ||||
| 402 | struct UnpackedInput { | |||
| 403 | THPObjectPtr input_tuple; | |||
| 404 | variable_list input_vars; | |||
| 405 | }; | |||
| 406 | ||||
| 407 | struct InputFlags { | |||
| 408 | bool is_executable = false; | |||
| 409 | edge_list next_edges; | |||
| 410 | THPObjectPtr needs_input_grad; | |||
| 411 | std::vector<bool> is_variable_input; | |||
| 412 | }; | |||
| 413 | ||||
| 414 | template<bool enforce_variables> | |||
| 415 | std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) { | |||
| 416 | UnpackedInput unpacked; | |||
| 417 | InputFlags flags; | |||
| 418 | ||||
| 419 | auto num_args = PyTuple_GET_SIZE(args); | |||
| 420 | unpacked.input_tuple = PyTuple_New(num_args); | |||
| 421 | flags.needs_input_grad = PyTuple_New(num_args); | |||
| 422 | for(const auto i : c10::irange(num_args)) { | |||
| 423 | PyObject *arg = PyTuple_GET_ITEM(args, i); | |||
| 424 | ||||
| 425 | bool is_variable = THPVariable_Check(arg); | |||
| 426 | flags.is_variable_input.push_back(is_variable); | |||
| 427 | if (!is_variable) { | |||
| 428 | // TODO: remove this code path once Variable and Tensor are merged in Python | |||
| 429 | if (enforce_variables) { | |||
| 430 | THPUtils_setError("expected a Tensor argument, but got %s", | |||
| 431 | THPUtils_typename(arg)); | |||
| 432 | throw python_error(); | |||
| 433 | } | |||
| 434 | Py_INCREF(Py_False); | |||
| 435 | PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False); | |||
| 436 | } else { | |||
| 437 | const auto& tensor = THPVariable_Unpack(arg); | |||
| 438 | unpacked.input_vars.push_back(tensor); | |||
| 439 | PyObject* needs_grad = tensor.requires_grad() ? Py_True : Py_False; | |||
| 440 | Py_INCREF(needs_grad); | |||
| 441 | PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad); | |||
| 442 | } | |||
| 443 | Py_INCREF(arg); | |||
| 444 | PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg); | |||
| 445 | } | |||
| 446 | ||||
| 447 | flags.is_executable = GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars); | |||
| 448 | flags.next_edges = (flags.is_executable ? collect_next_edges(unpacked.input_vars) : edge_list()); | |||
| 449 | return std::make_pair(std::move(unpacked), std::move(flags)); | |||
| 450 | } | |||
| 451 | ||||
| 452 | static torch::jit::Node* _trace_pre_record( | |||
| 453 | PyObject* op_obj, | |||
| 454 | PyObject *input_objects, | |||
| 455 | const variable_list& input_vars) { | |||
| 456 | if (!jit::tracer::isTracing()) { | |||
| 457 | return nullptr; | |||
| 458 | } | |||
| 459 | ||||
| 460 | // Save scalar args and the calling convention | |||
| 461 | auto num_args = PyTuple_GET_SIZE(input_objects); | |||
| 462 | pyobj_list scalar_args; | |||
| 463 | std::string arg_types; | |||
| 464 | arg_types.reserve(num_args); | |||
| 465 | scalar_args.reserve(num_args); | |||
| 466 | for(const auto i : c10::irange(num_args)) { | |||
| 467 | PyObject *arg_object = PyTuple_GET_ITEM(input_objects, i); | |||
| 468 | if (THPVariable_Check(arg_object)) { | |||
| 469 | arg_types.push_back('d'); | |||
| 470 | } else { | |||
| 471 | arg_types.push_back('c'); | |||
| 472 | Py_INCREF(arg_object); | |||
| 473 | scalar_args.emplace_back(arg_object); | |||
| 474 | } | |||
| 475 | } | |||
| 476 | ||||
| 477 | Py_INCREF(op_obj); | |||
| 478 | auto pyobj = THPObjectPtr(op_obj); | |||
| 479 | return jit::tracer::preRecordPythonTrace( | |||
| 480 | std::move(pyobj), arg_types, input_vars, std::move(scalar_args)); | |||
| 481 | } | |||
| 482 | ||||
| 483 | static void _trace_post_record( | |||
| 484 | torch::jit::Node* node, | |||
| 485 | PyObject* op_obj, | |||
| 486 | const variable_list& input_vars, | |||
| 487 | PyObject *output_objects, | |||
| 488 | bool is_inplace, | |||
| 489 | bool unpack_output) { | |||
| 490 | if (!jit::tracer::isTracing()) { | |||
| 491 | return; | |||
| 492 | } | |||
| 493 | ||||
| 494 | node->i_(jit::attr::inplace, is_inplace); | |||
| 495 | ||||
| 496 | // Isolate C variable ptrs in a vector | |||
| 497 | int num_outputs = PyTuple_GET_SIZE(output_objects); | |||
| 498 | auto graph = node->owningGraph(); | |||
| 499 | node->addOutput(); | |||
| 500 | if (!unpack_output) { | |||
| 501 | std::vector<TypePtr> tuple_values(num_outputs, TensorType::get()); | |||
| 502 | TypePtr tuple_type = TupleType::create(std::move(tuple_values)); | |||
| 503 | node->output()->setType(tuple_type); | |||
| 504 | auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node); | |||
| 505 | node = unpacked; | |||
| 506 | } | |||
| 507 | for (const auto i : c10::irange(num_outputs)) { | |||
| 508 | PyObject* obj = PyTuple_GET_ITEM(output_objects, i); | |||
| 509 | if (THPVariable_Check(obj)) { | |||
| 510 | Value* value = node->outputs()[i]; | |||
| 511 | const auto& tensor = THPVariable_Unpack(obj); | |||
| 512 | if (tensor.defined()) { | |||
| 513 | value->inferTypeFrom(tensor); | |||
| 514 | jit::tracer::setValueTrace(tensor, value); | |||
| 515 | } | |||
| 516 | } | |||
| 517 | } | |||
| 518 | } | |||
| 519 | ||||
| 520 | PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr<PyNode>& cdata, | |||
| 521 | THPFunction* grad_fn, const UnpackedInput& unpacked, | |||
| 522 | PyObject *inputs, THPObjectPtr&& raw_output, bool is_executable, | |||
| 523 | torch::jit::Node* node) { | |||
| 524 | bool unpack_output = ensure_tuple(raw_output); | |||
| 525 | ||||
| 526 | auto num_outputs = PyTuple_GET_SIZE(raw_output.get()); | |||
| 527 | ||||
| 528 | THPObjectPtr outputs(PyTuple_New(num_outputs)); | |||
| 529 | if (!outputs) throw python_error(); | |||
| 530 | ||||
| 531 | cdata->clear_input_metadata(); | |||
| 532 | ||||
| 533 | // Record type, device, and size information about inputs | |||
| 534 | if (is_executable) { | |||
| 535 | grad_fn->input_info.clear(); | |||
| 536 | grad_fn->input_info.reserve(unpacked.input_vars.size()); | |||
| 537 | for (auto& var : unpacked.input_vars) { | |||
| 538 | grad_fn->input_info.emplace_back(var); | |||
| 539 | } | |||
| 540 | } | |||
| 541 | ||||
| 542 | bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors); | |||
| 543 | _wrap_outputs(cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable); | |||
| 544 | _trace_post_record(node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output); | |||
| 545 | if (is_executable) { | |||
| 546 | _save_variables(cdata, grad_fn); | |||
| 547 | } else { | |||
| 548 | // Remove unnecessary attributes | |||
| 549 | Py_XDECREF(grad_fn->to_save); | |||
| 550 | grad_fn->to_save = nullptr; | |||
| 551 | Py_XDECREF(grad_fn->non_differentiable); | |||
| 552 | grad_fn->non_differentiable = nullptr; | |||
| 553 | } | |||
| 554 | ||||
| 555 | // Unpack the output, unless .forward() returned a tuple | |||
| 556 | if (unpack_output) { | |||
| 557 | PyObject *output = PyTuple_GET_ITEM(outputs.get(), 0); | |||
| 558 | Py_INCREF(output); | |||
| 559 | return output; | |||
| 560 | } | |||
| 561 | ||||
| 562 | return outputs.release(); | |||
| 563 | } | |||
| 564 | ||||
| 565 | PyObject* THPFunction_name(PyObject *self, PyObject* noargs) { | |||
| 566 | HANDLE_TH_ERRORS | |||
| 567 | auto cdata = ((THPFunction*)self)->cdata.lock(); | |||
| 568 | TORCH_CHECK(cdata, | |||
| 569 | "Attribute 'name' is invalid for this instance of _C._FunctionBase. " | |||
| 570 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " | |||
| 571 | "access pattern that is no longer supported. For examples on how to use new-style " | |||
| 572 | "autograd functions, see " | |||
| 573 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); | |||
| 574 | return THPUtils_packString(cdata->name()); | |||
| 575 | END_HANDLE_TH_ERRORS | |||
| 576 | } | |||
| 577 | ||||
| 578 | PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs) | |||
| 579 | { | |||
| 580 | HANDLE_TH_ERRORS | |||
| 581 | RECORD_FUNCTION( | |||
| 582 | ((PyTypeObject*)cls)->tp_name, | |||
| 583 | std::vector<c10::IValue>(), | |||
| 584 | at::sequence_number::peek()); | |||
| 585 | ||||
| 586 | THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls")); | |||
| 587 | if (!backward_cls) return nullptr; | |||
| 588 | THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr)); | |||
| 589 | if (!ctx_obj) return nullptr; | |||
| 590 | THPFunction* ctx = (THPFunction*)ctx_obj.get(); | |||
| 591 | ||||
| 592 | auto cdata = std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode); | |||
| 593 | ctx->cdata = cdata; | |||
| 594 | ||||
| 595 | // Prepare inputs and allocate context (grad fn) | |||
| 596 | auto info_pair = unpack_input<false>(inputs); | |||
| 597 | UnpackedInput& unpacked_input = info_pair.first; | |||
| 598 | InputFlags& input_info = info_pair.second; | |||
| 599 | ||||
| 600 | // Record input nodes if tracing | |||
| 601 | auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars); | |||
| 602 | ||||
| 603 | // Initialize backward function (and ctx) | |||
| 604 | bool is_executable = input_info.is_executable; | |||
| 605 | cdata->set_next_edges(std::move(input_info.next_edges)); | |||
| 606 | ctx->needs_input_grad = input_info.needs_input_grad.release(); | |||
| 607 | ctx->is_variable_input = std::move(input_info.is_variable_input); | |||
| 608 | ||||
| 609 | // Prepend ctx to input_tuple, in preparation for static method call | |||
| 610 | auto num_args = PyTuple_GET_SIZE(inputs); | |||
| 611 | THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1)); | |||
| 612 | if (!ctx_input_tuple) return nullptr; | |||
| 613 | Py_INCREF(ctx); | |||
| 614 | PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx); | |||
| 615 | for (const auto i : c10::irange(num_args)) { | |||
| 616 | PyObject *arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i); | |||
| 617 | Py_INCREF(arg); | |||
| 618 | PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg); | |||
| 619 | } | |||
| 620 | ||||
| 621 | // Call forward | |||
| 622 | THPObjectPtr tensor_outputs; | |||
| 623 | { | |||
| 624 | AutoGradMode grad_mode(false); | |||
| 625 | THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); | |||
| 626 | if (!forward_fn) return nullptr; | |||
| 627 | tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple); | |||
| 628 | if (!tensor_outputs) return nullptr; | |||
| 629 | } | |||
| 630 | ||||
| 631 | return process_outputs(cls, cdata, ctx, unpacked_input, inputs, std::move(tensor_outputs), | |||
| 632 | is_executable, node); | |||
| 633 | END_HANDLE_TH_ERRORS | |||
| 634 | } | |||
| 635 | ||||
| 636 | ||||
| 637 | //////////////////////////////////////////////////////////////////////////////// | |||
| 638 | // Other methods / attributes | |||
| 639 | //////////////////////////////////////////////////////////////////////////////// | |||
| 640 | ||||
| 641 | PyObject* THPFunction__register_hook_dict(PyObject *_self, PyObject *_var) | |||
| 642 | { | |||
| 643 | HANDLE_TH_ERRORS | |||
| 644 | THPUtils_assert(THPVariable_Check(_var), "_register_hook_dict expected a Tensor"); | |||
| 645 | THPVariable* var = reinterpret_cast<THPVariable*>(_var); | |||
| 646 | const auto& tensor = THPVariable_Unpack(var); | |||
| 647 | std::unique_ptr<FunctionPreHook> hook(new PyFunctionPreHook( | |||
| 648 | var->backward_hooks, tensor.output_nr())); | |||
| 649 | auto self = (THPFunction*)_self; | |||
| 650 | auto cdata = self->cdata.lock(); | |||
| 651 | TORCH_CHECK(cdata, | |||
| 652 | "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. " | |||
| 653 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " | |||
| 654 | "access pattern that is no longer supported. For examples on how to use new-style " | |||
| 655 | "autograd functions, see " | |||
| 656 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); | |||
| 657 | cdata->add_pre_hook(std::move(hook)); | |||
| 658 | Py_RETURN_NONE; | |||
| 659 | END_HANDLE_TH_ERRORS | |||
| 660 | } | |||
| 661 | ||||
| 662 | PyObject* THPFunction_register_hook(PyObject *_self, PyObject *hook) | |||
| 663 | { | |||
| 664 | HANDLE_TH_ERRORS | |||
| 665 | auto self = (THPFunction*)_self; | |||
| 666 | auto cdata = self->cdata.lock(); | |||
| 667 | TORCH_CHECK(cdata, | |||
| 668 | "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. " | |||
| 669 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " | |||
| 670 | "access pattern that is no longer supported. For examples on how to use new-style " | |||
| 671 | "autograd functions, see " | |||
| 672 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); | |||
| 673 | return torch::autograd::registerFunctionHook(*cdata, hook); | |||
| 674 | END_HANDLE_TH_ERRORS | |||
| 675 | } | |||
| 676 | ||||
| 677 | int THPFunction_set_materialize_grads(THPFunction *self, PyObject *value, void *unused) | |||
| 678 | { | |||
| 679 | HANDLE_TH_ERRORS | |||
| 680 | if (!PyBool_Check(value)) { | |||
| 681 | THPUtils_invalidArguments(value, nullptr, "set_materialize_grads", 1, "(bool)"); | |||
| 682 | return -1; | |||
| 683 | } | |||
| 684 | self->materialize_grads = (value == Py_True); | |||
| 685 | return 0; | |||
| 686 | END_HANDLE_TH_ERRORS_RET(-1) | |||
| 687 | } | |||
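
The setter above follows the standard CPython setter protocol: validate the incoming object, return -1 with a Python exception set on failure, and return 0 on success. Below is a minimal stand-alone sketch of that protocol; the DemoObject type and Demo_set_flag function are invented for illustration and are not part of this file.

#include <Python.h>

typedef struct {
  PyObject_HEAD
  int flag;
} DemoObject;

// Setter convention: 0 on success, -1 with an error set on failure.
// A nullptr value means the attribute is being deleted ("del obj.flag").
static int Demo_set_flag(PyObject* self, PyObject* value, void* closure) {
  if (value == nullptr) {
    PyErr_SetString(PyExc_AttributeError, "cannot delete 'flag'");
    return -1;
  }
  if (!PyBool_Check(value)) {
    PyErr_SetString(PyExc_TypeError, "flag must be a bool");
    return -1;
  }
  ((DemoObject*)self)->flag = (value == Py_True);
  return 0;
}
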
| 688 | ||||
| 689 | static PyObject *unpack_saved_variables( | |||
| 690 | THPFunction *self, | |||
| 691 | const std::function<PyObject*(const Variable&)>& unpack_fn) | |||
| 692 | { | |||
| 693 | THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE); | |||
| 694 | auto& saved_variables = self->saved_variables; | |||
| 695 | if (saved_variables.empty()) | |||
| 696 | return PyTuple_New(0); | |||
| 697 | ||||
| 698 | int num_saved = saved_variables.size(); | |||
| 699 | THPObjectPtr saved(PyTuple_New(num_saved)); | |||
| 700 | if (!saved) | |||
| 701 | return nullptr; | |||
| 702 | auto saved_for = self->cdata.lock(); | |||
| 703 | // This is really a true assert, because we've already tested for the | |||
| 704 | // self->has_freed_buffers case at the beginning of this function: | |||
| 705 | // buffers are freed when PyNode dies; if the buffers are not freed, | |||
| 706 | // PyNode must be live. (Note that the buffers could be freed | |||
| 707 | // even though the PyNode is live, but that doesn't matter here | |||
| 708 | // because we will never hit this line of code if the buffers are freed-- | |||
| 709 | // and in any case saved_for will be non-NULL.) | |||
| 710 | TORCH_INTERNAL_ASSERT(saved_for); | |||
| 711 | for (const auto i : c10::irange(num_saved)) { | |||
| 712 | auto unpacked_var = saved_variables[i].unpack(saved_for); | |||
| 713 | THPObjectPtr value; | |||
| 714 | if (!unpacked_var.defined()) { | |||
| 715 | Py_INCREF(Py_None); | |||
| 716 | value = Py_None; | |||
| 717 | } else { | |||
| 718 | value = unpack_fn(unpacked_var); | |||
| 719 | } | |||
| 720 | PyTuple_SET_ITEM(saved.get(), i, value.release()); | |||
| 721 | } | |||
| 722 | return saved.release(); | |||
| 723 | } | |||
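
unpack_saved_variables relies on the CPython convention that PyTuple_SET_ITEM steals the reference passed to it, which is why each element is handed over exactly once via release() on the owning THPObjectPtr. Below is a minimal stand-alone sketch of that ownership hand-off; the helper make_pair_tuple is invented for illustration and is not part of this file.

#include <Python.h>

// PyTuple_New returns a new (owned) reference; PyTuple_SET_ITEM then steals the
// reference for each element, so the element must not be DECREF'd afterwards.
static PyObject* make_pair_tuple(long a, long b) {
  PyObject* tup = PyTuple_New(2);
  if (!tup) return nullptr;
  PyObject* first = PyLong_FromLong(a);           // new reference
  if (!first) { Py_DECREF(tup); return nullptr; }
  PyTuple_SET_ITEM(tup, 0, first);                // tuple now owns 'first'
  PyObject* second = PyLong_FromLong(b);          // new reference
  if (!second) { Py_DECREF(tup); return nullptr; }
  PyTuple_SET_ITEM(tup, 1, second);               // tuple now owns 'second'
  return tup;                                     // caller owns the tuple
}
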
| 724 | ||||
| 725 | PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused) | |||
| 726 | { | |||
| 727 | HANDLE_TH_ERRORS | |||
| 728 | return unpack_saved_variables(self, [](const Variable& var) { | |||
| 729 | return THPVariable_Wrap(var); | |||
| 730 | }); | |||
| 731 | END_HANDLE_TH_ERRORS | |||
| 732 | } | |||
| 733 | ||||
| 734 | PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused) | |||
| 735 | { | |||
| 736 | HANDLE_TH_ERRORS | |||
| 737 | auto r = PyErr_WarnEx(PyExc_DeprecationWarning, | |||
| 738 | "'saved_variables' is deprecated; use 'saved_tensors'", 0); | |||
| 739 | if (r != 0) throw python_error(); | |||
| 740 | return unpack_saved_variables(self, [](const Variable& var) { | |||
| 741 | return THPVariable_Wrap(var); | |||
| 742 | }); | |||
| 743 | END_HANDLE_TH_ERRORS | |||
| 744 | } | |||
| 745 | ||||
| 746 | PyObject *THPFunction_next_functions(THPFunction *self, void *_unused) | |||
| 747 | { | |||
| 748 | HANDLE_TH_ERRORS | |||
| 749 | auto cdata = self->cdata.lock(); | |||
| 750 | TORCH_CHECK(cdata, | |||
| 751 | "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. " | |||
| 752 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " | |||
| 753 | "access pattern that is no longer supported. For examples on how to use new-style " | |||
| 754 | "autograd functions, see " | |||
| 755 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function "); | |||
| 756 | const auto num_outputs = cdata->num_outputs(); | |||
| 757 | THPObjectPtr result(PyTuple_New(num_outputs)); | |||
| 758 | if (!result) | |||
| 759 | return nullptr; | |||
| 760 | for (const auto i : c10::irange(num_outputs)) { | |||
| 761 | THPObjectPtr fn_tuple(PyTuple_New(2)); | |||
| 762 | if (!fn_tuple) return nullptr; | |||
| 763 | const auto& edge = cdata->next_edge(i); | |||
| 764 | PyObject* fn = functionToPyObject(edge.function); | |||
| 765 | if (!fn) return nullptr; | |||
| 766 | PyTuple_SET_ITEM(fn_tuple.get(), 0, fn); | |||
| 767 | PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr)); | |||
| 768 | PyTuple_SET_ITEM(result.get(), i, fn_tuple.release()); | |||
| 769 | } | |||
| 770 | return result.release(); | |||
| 771 | END_HANDLE_TH_ERRORS | |||
| 772 | } | |||
| 773 | ||||
| 774 | PyObject *THPFunction_metadata(THPFunction *self, void *_unused) | |||
| 775 | { | |||
| 776 | HANDLE_TH_ERRORS | |||
| 777 | auto cdata = self->cdata.lock(); | |||
| 778 | // The correct way to solve this problem is to stop exposing grad_fn | |||
| 779 | // of PyFunctions as THPFunction; instead, we should use THPCppFunction | |||
| 780 | // like everyone else. But this is a BC-breaking change as it would | |||
| 781 | // mean that you no longer get the property that grad_fn is a subclass | |||
| 782 | // of the autograd function class that you defined in the custom case, | |||
| 783 | // so I didn't fix it here. | |||
| 784 | TORCH_CHECK(cdata, | |||
| 785 | "You attempted to access the anomaly metadata of a custom autograd function " | |||
| 786 | "but the underlying PyNode has already been deallocated. The most likely " | |||
| 787 | "reason this occurred is because you assigned x.grad_fn to a local variable " | |||
| 788 | "and then let the original variable get deallocated. Don't do that! If " | |||
| 789 | "you really have no way of restructuring your code so this is the case, " | |||
| 790 | "please file an issue reporting that you are affected by this."); | |||
| 791 | auto metadata = static_cast<PyAnomalyMetadata*>(cdata->metadata())->dict(); | |||
| 792 | ||||
| 793 | Py_INCREF(metadata); | |||
| 794 | return metadata; | |||
| 795 | END_HANDLE_TH_ERRORS | |||
| 796 | } | |||
| 797 | ||||
| 798 | typedef PyObject *(*getter)(PyObject *, void *); | |||
| 799 | typedef int (*setter)(PyObject *, PyObject *, void *); | |||
| 800 | ||||
| 801 | namespace { | |||
| 802 | ||||
| 803 | template<PyObject* THPFunction::*ptr> | |||
| 804 | PyObject* getObject(PyObject* obj, void* _unused) { | |||
| 805 | auto self = (THPFunction*)obj; | |||
| 806 | PyObject* value = self->*ptr; | |||
| 807 | if (!value) { | |||
| 808 | Py_RETURN_NONE; | |||
| 809 | } | |||
| 810 | Py_INCREF(value); | |||
| 811 | return value; | |||
| 812 | } | |||
| 813 | ||||
| 814 | template<PyObject* THPFunction::*ptr> | |||
| 815 | int setObject(PyObject* obj, PyObject* value, void* _unused) { | |||
| 816 | auto self = (THPFunction*)obj; | |||
| 817 | if (value == Py_None) { | |||
| 818 | value = nullptr; | |||
| 819 | } | |||
| 820 | Py_XDECREF((self->*ptr)); | |||
| 821 | Py_XINCREF(value); | |||
| 822 | self->*ptr = value; | |||
| 823 | return 0; | |||
| 824 | } | |||
| 825 | ||||
| 826 | template<typename M, M THPFunction::*ptr, PyObject* (*Convert)(long)> | |||
| 827 | PyObject* getMember(PyObject* obj, void* _unused) { | |||
| 828 | auto self = (THPFunction*)obj; | |||
| 829 | return Convert(self->*ptr); | |||
| 830 | } | |||
| 831 | ||||
| 832 | template<typename M, M autograd::Node::*ptr, PyObject* (*Convert)(long)> | |||
| 833 | PyObject* getImplMember(PyObject* obj, void* _unused) { | |||
| 834 | auto self = (THPFunction*)obj; | |||
| 835 | return Convert(self->cdata.*ptr); | |||
| 836 | } | |||
| 837 | ||||
| 838 | PyObject* getRequiresGrad(PyObject* obj, void* _unused) { | |||
| 839 | Py_RETURN_TRUE; | |||
| 840 | } | |||
| 841 | ||||
| 842 | } | |||
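
The getObject/setObject/getMember helpers above use a pointer-to-member as a non-type template parameter, so each field of THPFunction gets its own generated accessor function whose address can be stored in a plain C function-pointer table such as PyGetSetDef. Below is a minimal stand-alone C++ sketch of that pattern; the Widget struct and getField/demo functions are invented for illustration and are not part of this file.

// Plain C++ illustration; Widget stands in for THPFunction.
struct Widget {
  int count;
  int limit;
};

// The member pointer is a compile-time template argument, so each instantiation
// is a distinct function with a fixed, ordinary signature.
template <int Widget::*ptr>
int getField(const Widget& w) {
  return w.*ptr;
}

int demo() {
  Widget w{3, 7};
  // getField<&Widget::count> and getField<&Widget::limit> are different functions.
  return getField<&Widget::count>(w) + getField<&Widget::limit>(w);  // 10
}
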
| 843 | ||||
| 844 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) | |||
| 845 | static struct PyGetSetDef THPFunction_properties[] = { | |||
| 846 | {"saved_tensors", (getter)THPFunction_saved_tensors, nullptr, nullptr, nullptr}, | |||
| 847 | {"saved_variables", (getter)THPFunction_saved_variables, nullptr, nullptr, nullptr}, | |||
| 848 | {"next_functions", (getter)THPFunction_next_functions, nullptr, nullptr, nullptr}, | |||
| 849 | {"to_save", &getObject<&THPFunction::to_save>, &setObject<&THPFunction::to_save>, nullptr, nullptr}, | |||
| 850 | {"non_differentiable", &getObject<&THPFunction::non_differentiable>, &setObject<&THPFunction::non_differentiable>, nullptr, nullptr}, | |||
| 851 | {"dirty_tensors", &getObject<&THPFunction::dirty_tensors>, &setObject<&THPFunction::dirty_tensors>, nullptr, nullptr}, | |||
| 852 | {"needs_input_grad", &getObject<&THPFunction::needs_input_grad>, nullptr, nullptr, nullptr}, | |||
| 853 | {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr}, | |||
| 854 | {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr}, | |||
| 855 | {"materialize_grads", nullptr, (setter)THPFunction_set_materialize_grads, nullptr, nullptr}, | |||
| 856 | {nullptr} | |||
| 857 | }; | |||
| 858 | ||||
| 859 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) | |||
| 860 | static struct PyMethodDef THPFunction_methods[] = { | |||
| 861 | {(char*)"name", THPFunction_name, METH_NOARGS0x0004, nullptr}, | |||
| 862 | {(char*)"apply", THPFunction_apply, METH_CLASS0x0010 | METH_VARARGS0x0001, nullptr}, | |||
| 863 | {(char*)"_register_hook_dict", THPFunction__register_hook_dict, METH_O0x0008, nullptr}, | |||
| 864 | {(char*)"register_hook", THPFunction_register_hook, METH_O0x0008, nullptr}, | |||
| 865 | {nullptr} | |||
| 866 | }; | |||
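
In the method table above, METH_O methods receive their single argument directly, METH_VARARGS methods receive an argument tuple, and METH_CLASS binds the callable to the type object rather than to instances (which is how apply is exposed as a class-level entry point). Below is a minimal stand-alone sketch of these calling conventions; the demo_echo/demo_sum functions and demo_methods table are invented for illustration and are not part of this file.

#include <Python.h>

static PyObject* demo_echo(PyObject* self, PyObject* arg) {   // METH_O: one argument
  Py_INCREF(arg);
  return arg;
}

static PyObject* demo_sum(PyObject* self, PyObject* args) {   // METH_VARARGS: tuple
  long a = 0, b = 0;
  if (!PyArg_ParseTuple(args, "ll", &a, &b))
    return nullptr;
  return PyLong_FromLong(a + b);
}

static PyMethodDef demo_methods[] = {
  {"echo", demo_echo, METH_O, nullptr},
  {"sum", demo_sum, METH_VARARGS, nullptr},
  {nullptr, nullptr, 0, nullptr}
};
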
| 867 | ||||
| 868 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) | |||
| 869 | PyTypeObject THPFunctionType = { | |||
| 870 | PyVarObject_HEAD_INIT(nullptr, 0) | |||
| 871 | "torch._C._FunctionBase", /* tp_name */ | |||
| 872 | sizeof(THPFunction), /* tp_basicsize */ | |||
| 873 | 0, /* tp_itemsize */ | |||
| 874 | (destructor)THPFunction_dealloc, /* tp_dealloc */ | |||
| 875 | // NOLINTNEXTLINE(modernize-use-nullptr) | |||
| 876 | 0, /* tp_vectorcall_offset */ | |||
| 877 | nullptr, /* tp_getattr */ | |||
| 878 | nullptr, /* tp_setattr */ | |||
| 879 | nullptr, /* tp_reserved */ | |||
| 880 | nullptr, /* tp_repr */ | |||
| 881 | nullptr, /* tp_as_number */ | |||
| 882 | nullptr, /* tp_as_sequence */ | |||
| 883 | nullptr, /* tp_as_mapping */ | |||
| 884 | nullptr, /* tp_hash */ | |||
| 885 | nullptr, /* tp_call */ | |||
| 886 | nullptr, /* tp_str */ | |||
| 887 | nullptr, /* tp_getattro */ | |||
| 888 | nullptr, /* tp_setattro */ | |||
| 889 | nullptr, /* tp_as_buffer */ | |||
| 890 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ | |||
| 891 | nullptr, /* tp_doc */ | |||
| 892 | (traverseproc)THPFunction_traverse, /* tp_traverse */ | |||
| 893 | (inquiry)THPFunction_clear, /* tp_clear */ | |||
| 894 | nullptr, /* tp_richcompare */ | |||
| 895 | 0, /* tp_weaklistoffset */ | |||
| 896 | nullptr, /* tp_iter */ | |||
| 897 | nullptr, /* tp_iternext */ | |||
| 898 | THPFunction_methods, /* tp_methods */ | |||
| 899 | nullptr, /* tp_members */ | |||
| 900 | THPFunction_properties, /* tp_getset */ | |||
| 901 | nullptr, /* tp_base */ | |||
| 902 | nullptr, /* tp_dict */ | |||
| 903 | nullptr, /* tp_descr_get */ | |||
| 904 | nullptr, /* tp_descr_set */ | |||
| 905 | 0, /* tp_dictoffset */ | |||
| 906 | nullptr, /* tp_init */ | |||
| 907 | nullptr, /* tp_alloc */ | |||
| 908 | THPFunction_new /* tp_new */ | |||
| 909 | }; | |||
| 910 | ||||
| 911 | bool THPFunction_initModule(PyObject *module) | |||
| 912 | { | |||
| 913 | if (PyType_Ready(&THPFunctionType) < 0) | |||
| 914 | return false; | |||
| 915 | Py_INCREF(&THPFunctionType); | |||
| 916 | PyModule_AddObject(module, "_FunctionBase", (PyObject *)&THPFunctionType); | |||
| 917 | return true; | |||
| 918 | } |
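
THPFunction_initModule follows the usual CPython registration sequence: PyType_Ready finalizes the type, an extra reference is taken on it for the module, and PyModule_AddObject publishes it under a Python-visible name. Below is a minimal stand-alone sketch of that sequence; DemoType, Demo_initModule, and the "demo.Demo" name are invented for illustration and are not part of this file. The sketch also checks the return value of PyModule_AddObject (which steals the reference only on success), something the listing above does not do.

#include <Python.h>

static PyTypeObject DemoType = {
  PyVarObject_HEAD_INIT(nullptr, 0)
  "demo.Demo",        /* tp_name */
  sizeof(PyObject),   /* tp_basicsize */
};

static bool Demo_initModule(PyObject* module) {
  DemoType.tp_flags = Py_TPFLAGS_DEFAULT;
  DemoType.tp_new = PyType_GenericNew;
  if (PyType_Ready(&DemoType) < 0)
    return false;
  Py_INCREF(&DemoType);  // extra reference for the module's table entry
  if (PyModule_AddObject(module, "Demo", (PyObject*)&DemoType) < 0) {
    Py_DECREF(&DemoType);  // PyModule_AddObject steals the reference only on success
    return false;
  }
  return true;
}
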
| 1 | #ifndef PyObject_CallObject |
| 2 | struct _object; |
| 3 | typedef struct _object PyObject; |
| 4 | PyObject* clang_analyzer_PyObject_New_Reference(); |
| 5 | PyObject* PyObject_CallObject(PyObject *callable, PyObject *args) { |
| 6 | return clang_analyzer_PyObject_New_Reference(); |
| 7 | } |
| 8 | #else |
| 9 | #warning "API PyObject_CallObject is defined as a macro." |
| 10 | #endif |