File: | build/../torch/csrc/distributed/rpc/init.cpp |
Warning: | line 36, column 20 — PyObject ownership leak with reference count of 1 (reported for the new reference returned by PyImport_ImportModule("torch.distributed.rpc") on that line; the result is immediately wrapped in THPObjectPtr, so verify whether the wrapper's RAII release makes this a false positive) |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | #include <torch/csrc/python_headers.h> | |||
2 | ||||
3 | #include <torch/csrc/distributed/rpc/process_group_agent.h> | |||
4 | #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> | |||
5 | #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h> | |||
6 | #include <torch/csrc/distributed/rpc/py_rref.h> | |||
7 | #include <torch/csrc/distributed/rpc/python_functions.h> | |||
8 | #include <torch/csrc/distributed/rpc/python_rpc_handler.h> | |||
9 | #include <torch/csrc/distributed/rpc/request_callback_impl.h> | |||
10 | #include <torch/csrc/distributed/rpc/rpc_agent.h> | |||
11 | #include <torch/csrc/distributed/rpc/rref_context.h> | |||
12 | #include <torch/csrc/distributed/rpc/tensorpipe_agent.h> | |||
13 | #include <torch/csrc/distributed/rpc/torchscript_functions.h> | |||
14 | #include <torch/csrc/distributed/rpc/types.h> | |||
15 | #include <torch/csrc/jit/python/pybind_utils.h> | |||
16 | #include <torch/csrc/utils/object_ptr.h> | |||
17 | #include <torch/csrc/utils/pybind.h> | |||
18 | #include <torch/types.h> | |||
19 | ||||
20 | #include <pybind11/chrono.h> | |||
21 | #include <pybind11/operators.h> | |||
22 | ||||
// ---- torch::distributed::rpc pybind11 module internals ----
// NOTE(review): this is an analyzer-rendered view; every code line below
// carries the viewer's own "NN |" prefix and "| |||" column separators —
// the actual C++ source is the middle column.
23 | namespace torch { | |||
24 | namespace distributed { | |||
25 | namespace rpc { | |||
26 | ||||
// Anonymous namespace: everything below is file-local to init.cpp.
27 | namespace { | |||
28 | ||||
// Timeout (100000 ms = 100 s) used when deleting all UserRRefs — presumably
// consumed by the shutdown path further down this file; TODO confirm usage
// site (it is below the visible chunk).
29 | constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000); | |||
30 | ||||
// Shorthand for a pybind11 class bound with a std::shared_ptr holder type;
// used for all RPC bindings registered in rpc_init below.
31 | template <typename T> | |||
32 | using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>; | |||
33 | ||||
// rpc_init: CPython-callable initializer (METH_NOARGS-style signature) that
// registers the torch._C._distributed_rpc submodule and its pybind11
// bindings. It first imports the Python-side "torch.distributed.rpc" and
// "torch._C" modules, then hangs a "_distributed_rpc" submodule off
// torch._C on which all subsequent bindings are defined.
34 | PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { | |||
// NOTE(review): the analyzer flags the next line ("PyObject ownership leak
// with reference count of 1"). PyImport_ImportModule returns a new
// reference, which is handed directly to the RAII wrapper THPObjectPtr
// (declared in torch/csrc/utils/object_ptr.h, included above) — presumably
// the wrapper decrefs on destruction, which would make this report a false
// positive; TODO confirm THPObjectPtr's ownership semantics.
35 | auto rpc_module = | |||
36 | THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc")); | |||
| ||||
| ||||
// Import failure leaves a pending Python exception; python_error surfaces
// it to the caller instead of continuing with a null module.
37 | if (!rpc_module) { | |||
38 | throw python_error(); | |||
39 | } | |||
40 | ||||
41 | auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); | |||
42 | if (!torch_C_module) { | |||
43 | throw python_error(); | |||
44 | } | |||
45 | ||||
// Wrap the raw torch._C handle as a py::module and create the
// "_distributed_rpc" submodule; every binding below registers on it.
46 | auto torch_C_m = py::handle(torch_C_module).cast<py::module>(); | |||
47 | auto m = | |||
48 | torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); | |||
49 | ||||
50 | auto module = py::handle(m).cast<py::module>(); | |||
51 | ||||
52 | auto rpcBackendOptions = | |||
53 | shared_ptr_class_<RpcBackendOptions>( | |||
54 | module, | |||
55 | "RpcBackendOptions", | |||
56 | R"(An abstract structure encapsulating the options passed into the RPC | |||
57 | backend. An instance of this class can be passed in to | |||
58 | :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC | |||
59 | with specific configurations, such as the RPC timeout and | |||
60 | ``init_method`` to be used. )") | |||
61 | .def(py::init<>()) | |||
62 | .def( | |||
63 | py::init<float, std::string>(), | |||
64 | py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, | |||
65 | py::arg("init_method") = kDefaultInitMethod) | |||
66 | .def_readwrite( | |||
67 | "rpc_timeout", | |||
68 | &RpcBackendOptions::rpcTimeoutSeconds, | |||
69 | R"(A float indicating the timeout to use for all | |||
70 | RPCs. If an RPC does not complete in this timeframe, it will | |||
71 | complete with an exception indicating that it has timed out.)") | |||
72 | .def_readwrite( | |||
73 | "init_method", | |||
74 | &RpcBackendOptions::initMethod, | |||
75 | R"(URL specifying how to initialize the process group. | |||
76 | Default is ``env://``)"); | |||
77 | ||||
78 | // The following C++ constants need to be cast so they can be used from | |||
79 | // python. | |||
80 | module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds); | |||
81 | module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout); | |||
82 | module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod); | |||
83 | ||||
84 | auto workerInfo = | |||
85 | shared_ptr_class_<WorkerInfo>( | |||
86 | module, | |||
87 | "WorkerInfo", | |||
88 | R"(A structure that encapsulates information of a worker in the system. | |||
89 | Contains the name and ID of the worker. This class is not meant to | |||
90 | be constructed directly, rather, an instance can be retrieved | |||
91 | through :meth:`~torch.distributed.rpc.get_worker_info` and the | |||
92 | result can be passed in to functions such as | |||
93 | :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`, | |||
94 | :meth:`~torch.distributed.rpc.remote` to avoid copying a string on | |||
95 | every invocation.)") | |||
96 | .def( | |||
97 | py::init<std::string, worker_id_t>(), | |||
98 | py::arg("name"), | |||
99 | py::arg("id")) | |||
100 | .def_readonly( | |||
101 | "name", &WorkerInfo::name_, R"(The name of the worker.)") | |||
102 | .def_readonly( | |||
103 | "id", | |||
104 | &WorkerInfo::id_, | |||
105 | R"(Globally unique id to identify the worker.)") | |||
106 | .def("__eq__", &WorkerInfo::operator==, py::is_operator()) | |||
107 | // pybind11 suggests the syntax .def(hash(py::self)), with the | |||
108 | // unqualified "hash" function call. However the | |||
109 | // argument-dependent lookup for the function "hash" doesn't get | |||
110 | // triggered in this context because it conflicts with the struct | |||
111 | // c10::hash, so we need to use the qualified name | |||
112 | // py::detail::hash, which unfortunately is in a detail namespace. | |||
113 | .def(py::detail::hash(py::self)) // NOLINT | |||
114 | .def("__repr__", [](const WorkerInfo& workerInfo) { | |||
115 | std::ostringstream os; | |||
116 | os << workerInfo; | |||
117 | return os.str(); | |||
118 | }); | |||
119 | ||||
120 | auto rpcAgent = | |||
121 | shared_ptr_class_<RpcAgent>(module, "RpcAgent") | |||
122 | .def( | |||
123 | "join", | |||
124 | &RpcAgent::join, | |||
125 | py::call_guard<py::gil_scoped_release>(), | |||
126 | py::arg("shutdown") = false) | |||
127 | .def( | |||
128 | "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>()) | |||
129 | .def( | |||
130 | "shutdown", | |||
131 | &RpcAgent::shutdown, | |||
132 | py::call_guard<py::gil_scoped_release>()) | |||
133 | .def( | |||
134 | "get_worker_info", | |||
135 | (const WorkerInfo& (RpcAgent::*)(void) const) & | |||
136 | RpcAgent::getWorkerInfo, | |||
137 | py::call_guard<py::gil_scoped_release>()) | |||
138 | .def( | |||
139 | "get_worker_info", | |||
140 | (const WorkerInfo& (RpcAgent::*)(const std::string&) const) & | |||
141 | RpcAgent::getWorkerInfo, | |||
142 | py::call_guard<py::gil_scoped_release>()) | |||
143 | .def( | |||
144 | "get_worker_infos", | |||
145 | &RpcAgent::getWorkerInfos, | |||
146 | py::call_guard<py::gil_scoped_release>()) | |||
147 | .def( | |||
148 | "_get_device_map", | |||
149 | &RpcAgent::getDeviceMap, | |||
150 | py::call_guard<py::gil_scoped_release>()) | |||
151 | .def( | |||
152 | "get_debug_info", | |||
153 | &RpcAgent::getDebugInfo, | |||
154 | py::call_guard<py::gil_scoped_release>()) | |||
155 | .def( | |||
156 | "get_metrics", | |||
157 | &RpcAgent::getMetrics, | |||
158 | py::call_guard<py::gil_scoped_release>()); | |||
159 | ||||
160 | auto pyRRef = | |||
161 | shared_ptr_class_<PyRRef>(module, "PyRRef", R"( | |||
162 | A class encapsulating a reference to a value of some type on a remote | |||
163 | worker. This handle will keep the referenced remote value alive on the | |||
164 | worker. A ``UserRRef`` will be deleted when 1) no references to it in | |||
165 | both the application code and in the local RRef context, or 2) the | |||
166 | application has called a graceful shutdown. Invoking methods on a | |||
167 | deleted RRef leads to undefined behaviors. RRef implementation only | |||
168 | offers best-effort error detection, and applications should not use | |||
169 | ``UserRRefs`` after ``rpc.shutdown()``. | |||
170 | ||||
171 | .. warning:: | |||
172 | RRefs can only be serialized and deserialized by the RPC module. | |||
173 | Serializing and deserializing RRefs without RPC (e.g., Python | |||
174 | pickle, torch :meth:`~torch.save` / :meth:`~torch.load`, | |||
175 | JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will | |||
176 | lead to errors. | |||
177 | ||||
178 | Args: | |||
179 | value (object): The value to be wrapped by this RRef. | |||
180 | type_hint (Type, optional): Python type that should be passed to | |||
181 | ``TorchScript`` compiler as type hint for ``value``. | |||
182 | ||||
183 | Example:: | |||
184 | Following examples skip RPC initialization and shutdown code | |||
185 | for simplicity. Refer to RPC docs for those details. | |||
186 | ||||
187 | 1. Create an RRef using rpc.remote | |||
188 | ||||
189 | >>> import torch | |||
190 | >>> import torch.distributed.rpc as rpc | |||
191 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) | |||
192 | >>> # get a copy of value from the RRef | |||
193 | >>> x = rref.to_here() | |||
194 | ||||
195 | 2. Create an RRef from a local object | |||
196 | ||||
197 | >>> import torch | |||
198 | >>> from torch.distributed.rpc import RRef | |||
199 | >>> x = torch.zeros(2, 2) | |||
200 | >>> rref = RRef(x) | |||
201 | ||||
202 | 3. Share an RRef with other workers | |||
203 | ||||
204 | >>> # On both worker0 and worker1: | |||
205 | >>> def f(rref): | |||
206 | >>> return rref.to_here() + 1 | |||
207 | ||||
208 | >>> # On worker0: | |||
209 | >>> import torch | |||
210 | >>> import torch.distributed.rpc as rpc | |||
211 | >>> from torch.distributed.rpc import RRef | |||
212 | >>> rref = RRef(torch.zeros(2, 2)) | |||
213 | >>> # the following RPC shares the rref with worker1, reference | |||
214 | >>> # count is automatically updated. | |||
215 | >>> rpc.rpc_sync("worker1", f, args=(rref,)) | |||
216 | )") | |||
217 | .def( | |||
218 | py::init<const py::object&, const py::object&>(), | |||
219 | py::arg("value"), | |||
220 | py::arg("type_hint") = py::none()) | |||
221 | .def( | |||
222 | // not releasing GIL here to avoid context switch on getters | |||
223 | "is_owner", | |||
224 | &PyRRef::isOwner, | |||
225 | R"( | |||
226 | Returns whether or not the current node is the owner of this | |||
227 | ``RRef``. | |||
228 | )") | |||
229 | .def( | |||
230 | "confirmed_by_owner", | |||
231 | &PyRRef::confirmedByOwner, | |||
232 | R"( | |||
233 | Returns whether this ``RRef`` has been confirmed by the owner. | |||
234 | ``OwnerRRef`` always returns true, while ``UserRRef`` only | |||
235 | returns true when the owner knowns about this ``UserRRef``. | |||
236 | )") | |||
237 | .def( | |||
238 | // not releasing GIL here to avoid context switch on getters | |||
239 | "owner", | |||
240 | &PyRRef::owner, | |||
241 | R"( | |||
242 | Returns worker information of the node that owns this ``RRef``. | |||
243 | )") | |||
244 | .def( | |||
245 | // not releasing GIL here to avoid context switch on getters | |||
246 | "owner_name", | |||
247 | &PyRRef::ownerName, | |||
248 | R"( | |||
249 | Returns worker name of the node that owns this ``RRef``. | |||
250 | )") | |||
251 | .def( | |||
252 | "to_here", | |||
253 | &PyRRef::toHere, | |||
254 | py::arg("timeout") = py::cast(kUnsetRpcTimeout), | |||
255 | py::call_guard<py::gil_scoped_release>(), | |||
256 | R"( | |||
257 | Blocking call that copies the value of the RRef from the owner | |||
258 | to the local node and returns it. If the current node is the | |||
259 | owner, returns a reference to the local value. | |||
260 | ||||
261 | Args: | |||
262 | timeout (float, optional): Timeout for ``to_here``. If | |||
263 | the call does not complete within this timeframe, an | |||
264 | exception indicating so will be raised. If this | |||
265 | argument is not provided, the default RPC timeout | |||
266 | (60s) will be used. | |||
267 | )") | |||
268 | .def( | |||
269 | "local_value", | |||
270 | &PyRRef::localValue, | |||
271 | py::call_guard<py::gil_scoped_release>(), | |||
272 | R"( | |||
273 | If the current node is the owner, returns a reference to the | |||
274 | local value. Otherwise, throws an exception. | |||
275 | )") | |||
276 | .def( | |||
277 | "rpc_sync", | |||
278 | [](const PyRRef& self, float timeoutSeconds) { | |||
279 | return self.createRRefProxy( | |||
280 | RRefProxyType::RPC_SYNC, timeoutSeconds); | |||
281 | }, | |||
282 | py::arg("timeout") = kUnsetRpcTimeout, | |||
283 | py::call_guard<py::gil_scoped_release>(), | |||
284 | R"( | |||
285 | Create a helper proxy to easily launch an ``rpc_sync`` using | |||
286 | the owner of the RRef as the destination to run functions on | |||
287 | the object referenced by this RRef. More specifically, | |||
288 | ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as | |||
289 | the following: | |||
290 | ||||
291 | >>> def run(rref, func_name, args, kwargs): | |||
292 | >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) | |||
293 | >>> | |||
294 | >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs)) | |||
295 | ||||
296 | Args: | |||
297 | timeout (float, optional): Timeout for ``rref.rpc_sync()``. | |||
298 | If the call does not complete within this timeframe, an | |||
299 | exception indicating so will be raised. If this argument | |||
300 | is not provided, the default RPC timeout will be used. | |||
301 | ||||
302 | Example:: | |||
303 | >>> from torch.distributed import rpc | |||
304 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) | |||
305 | >>> rref.rpc_sync().size() # returns torch.Size([2, 2]) | |||
306 | >>> rref.rpc_sync().view(1, 4) # returns tensor([[1., 1., 1., 1.]]) | |||
307 | )") | |||
308 | .def( | |||
309 | "rpc_async", | |||
310 | [](const PyRRef& self, float timeoutSeconds) { | |||
311 | return self.createRRefProxy( | |||
312 | RRefProxyType::RPC_ASYNC, timeoutSeconds); | |||
313 | }, | |||
314 | py::arg("timeout") = kUnsetRpcTimeout, | |||
315 | py::call_guard<py::gil_scoped_release>(), | |||
316 | R"( | |||
317 | Create a helper proxy to easily launch an ``rpc_async`` using | |||
318 | the owner of the RRef as the destination to run functions on | |||
319 | the object referenced by this RRef. More specifically, | |||
320 | ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as | |||
321 | the following: | |||
322 | ||||
323 | >>> def run(rref, func_name, args, kwargs): | |||
324 | >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) | |||
325 | >>> | |||
326 | >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs)) | |||
327 | ||||
328 | Args: | |||
329 | timeout (float, optional): Timeout for ``rref.rpc_async()``. | |||
330 | If the call does not complete within this timeframe, an | |||
331 | exception indicating so will be raised. If this argument | |||
332 | is not provided, the default RPC timeout will be used. | |||
333 | ||||
334 | Example:: | |||
335 | >>> from torch.distributed import rpc | |||
336 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) | |||
337 | >>> rref.rpc_async().size().wait() # returns torch.Size([2, 2]) | |||
338 | >>> rref.rpc_async().view(1, 4).wait() # returns tensor([[1., 1., 1., 1.]]) | |||
339 | )") | |||
340 | .def( | |||
341 | "remote", | |||
342 | [](const PyRRef& self, float timeoutSeconds) { | |||
343 | return self.createRRefProxy( | |||
344 | RRefProxyType::REMOTE, timeoutSeconds); | |||
345 | }, | |||
346 | py::arg("timeout") = kUnsetRpcTimeout, | |||
347 | py::call_guard<py::gil_scoped_release>(), | |||
348 | R"( | |||
349 | Create a helper proxy to easily launch a ``remote`` using | |||
350 | the owner of the RRef as the destination to run functions on | |||
351 | the object referenced by this RRef. More specifically, | |||
352 | ``rref.remote().func_name(*args, **kwargs)`` is the same as | |||
353 | the following: | |||
354 | ||||
355 | >>> def run(rref, func_name, args, kwargs): | |||
356 | >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) | |||
357 | >>> | |||
358 | >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs)) | |||
359 | ||||
360 | Args: | |||
361 | timeout (float, optional): Timeout for ``rref.remote()``. If | |||
362 | the creation of this :class:`~torch.distributed.rpc.RRef` | |||
363 | is not successfully completed within the timeout, then the | |||
364 | next time there is an attempt to use the RRef | |||
365 | (such as ``to_here``), a timeout will be raised. If not | |||
366 | provided, the default RPC timeout will be used. Please see | |||
367 | ``rpc.remote()`` for specific timeout semantics for | |||
368 | :class:`~torch.distributed.rpc.RRef`. | |||
369 | ||||
370 | Example:: | |||
371 | >>> from torch.distributed import rpc | |||
372 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) | |||
373 | >>> rref.remote().size().to_here() # returns torch.Size([2, 2]) | |||
374 | >>> rref.remote().view(1, 4).to_here() # returns tensor([[1., 1., 1., 1.]]) | |||
375 | )") | |||
376 | .def( | |||
377 | py::pickle( | |||
378 | /* __getstate__ */ | |||
379 | [](const PyRRef& /* unused */) { | |||
380 | TORCH_CHECK(if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(383), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not pickle rref in python pickler, rref can only be " "pickled when using RPC"))); } | |||
381 | false,if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(383), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not pickle rref in python pickler, rref can only be " "pickled when using RPC"))); } | |||
382 | "Can not pickle rref in python pickler, rref can only be "if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(383), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not pickle rref in python pickler, rref can only be " "pickled when using RPC"))); } | |||
383 | "pickled when using RPC")if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(383), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not pickle rref in python pickler, rref can only be " "pickled when using RPC"))); }; | |||
384 | // Note that this return has no meaning since we always | |||
385 | // throw, it's only here to satisfy Pybind API's | |||
386 | // requirement. | |||
387 | return py::make_tuple(); | |||
388 | }, | |||
389 | /* __setstate__ */ | |||
390 | [](py::tuple /* unused */) { // NOLINT | |||
391 | TORCH_CHECK(if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(394), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not unpickle rref in python pickler, rref can only be " "unpickled when using RPC"))); } | |||
392 | false,if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(394), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not unpickle rref in python pickler, rref can only be " "unpickled when using RPC"))); } | |||
393 | "Can not unpickle rref in python pickler, rref can only be "if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(394), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not unpickle rref in python pickler, rref can only be " "unpickled when using RPC"))); } | |||
394 | "unpickled when using RPC")if ((__builtin_expect(static_cast<bool>(!(false)), 0))) { ::c10::detail::torchCheckFail( __func__, "../torch/csrc/distributed/rpc/init.cpp" , static_cast<uint32_t>(394), (::c10::detail::torchCheckMsgImpl ( "Expected " "false" " to be true, but got false. " "(Could this error message be improved? If so, " "please report an enhancement request to PyTorch.)", "Can not unpickle rref in python pickler, rref can only be " "unpickled when using RPC"))); }; | |||
395 | // Note that this return has no meaning since we always | |||
396 | // throw, it's only here to satisfy PyBind's API | |||
397 | // requirement. | |||
398 | return PyRRef( | |||
399 | py::cast<py::none>(Py_None(&_Py_NoneStruct)), | |||
400 | py::cast<py::none>(Py_None(&_Py_NoneStruct))); | |||
401 | }), | |||
402 | py::call_guard<py::gil_scoped_release>()) | |||
403 | .def( | |||
404 | "_serialize", | |||
405 | &PyRRef::pickle, | |||
406 | py::call_guard<py::gil_scoped_release>()) | |||
407 | .def_static( | |||
408 | "_deserialize", | |||
409 | &PyRRef::unpickle, | |||
410 | py::call_guard<py::gil_scoped_release>()) | |||
411 | .def( | |||
412 | "_get_type", | |||
413 | // Intentionally not releasing GIL, as most accesses just | |||
414 | // retrieve cached type py::object | |||
415 | &PyRRef::getRRefType, | |||
416 | py::arg("timeout") = kUnsetRpcTimeout, | |||
417 | py::arg("blocking") = true, | |||
418 | R"( | |||
419 | If ``blocking=True``, returns the type of the data object | |||
420 | referenced by this ``RRef``. On the owner, this is same as | |||
421 | ``type(rref.local_value())``. Otherwise, returns a future to | |||
422 | this result. On a user, this will trigger an RPC to fetch the | |||
423 | ``type`` object from the owner. After this function is run | |||
424 | once, the ``type`` object is cached by the ``RRef``, and | |||
425 | subsequent invocations no longer trigger RPC. Note that this is | |||
426 | true regardless of the ``blocking`` argument of subsequent | |||
427 | calls. | |||
428 | ||||
429 | Args: | |||
430 | rref (torch.distributed.rpc.RRef): The RRef to get type of. | |||
431 | timeout (float, optional): Timeout, in seconds for | |||
432 | ``_get_type``. If the call does not complete within | |||
433 | this timeframe, an exception indicating so will be | |||
434 | raised. If this argument is not provided, the default | |||
435 | RPC timeout will be used. | |||
436 | blocking (bool, optional): Whether to synchronously wait on | |||
437 | the RPC triggered by the first call and return the | |||
438 | type. If ``False``, will return a future. Default is | |||
439 | ``True``. | |||
440 | )") | |||
441 | .def( | |||
442 | "_get_future", | |||
443 | [](const PyRRef& self) { | |||
444 | return std::make_shared<jit::PythonFutureWrapper>( | |||
445 | self.getFuture()); | |||
446 | }, | |||
447 | py::call_guard<py::gil_scoped_release>(), | |||
448 | R"( | |||
449 | Returns the future that corresponds to the creation of this RRef | |||
450 | on the remote node. This is for internal use cases such as profiling | |||
451 | only. | |||
452 | )") | |||
453 | .def( | |||
454 | "_get_profiling_future", | |||
455 | [](const PyRRef& self) { | |||
456 | return std::make_shared<jit::PythonFutureWrapper>( | |||
457 | self.getProfilingFuture()); | |||
458 | }, | |||
459 | py::call_guard<py::gil_scoped_acquire>(), | |||
460 | R"( | |||
461 | Returns future that completes when the profiling event corresponding | |||
462 | to the creation of this RRef on the remote node has been recorded. | |||
463 | )") | |||
464 | .def( | |||
465 | "_set_profiling_future", | |||
466 | [](PyRRef& self, | |||
467 | const std::shared_ptr<jit::PythonFutureWrapper>& | |||
468 | wrappedFuture) { | |||
469 | self.setProfilingFuture(wrappedFuture->fut); | |||
470 | }, | |||
471 | py::call_guard<py::gil_scoped_acquire>(), | |||
472 | R"( | |||
473 | Set future that is completed when the profiling event corresponding | |||
474 | to the creation of this RRef on the remote node has been recorded. | |||
475 | )") | |||
476 | .def( | |||
477 | "backward", | |||
478 | [](PyRRef& self, | |||
479 | int64_t dist_autograd_ctx_id, | |||
480 | bool retain_graph) { | |||
481 | self.backward(dist_autograd_ctx_id, retain_graph); | |||
482 | }, | |||
483 | py::arg("dist_autograd_ctx_id") = -1, | |||
484 | py::arg("retain_graph") = false, | |||
485 | py::call_guard<py::gil_scoped_release>(), | |||
486 | R"( | |||
487 | Runs the backward pass using the RRef as the root of the | |||
488 | backward pass. If ``dist_autograd_ctx_id`` is provided, | |||
489 | we perform a distributed backward pass using the provided | |||
490 | ctx_id starting from the owner of the RRef. In this case, | |||
491 | :meth:`~torch.distributed.autograd.get_gradients` should be | |||
492 | used to retrieve the gradients. If ``dist_autograd_ctx_id`` | |||
493 | is ``None``, it is assumed that this is a local autograd graph | |||
494 | and we only perform a local backward pass. In the local case, | |||
495 | the node calling this API has to be the owner of the RRef. | |||
496 | The value of the RRef is expected to be a scalar Tensor. | |||
497 | ||||
498 | Args: | |||
499 | dist_autograd_ctx_id (int, optional): The distributed | |||
500 | autograd context id for which we should retrieve the | |||
501 | gradients (default: -1). | |||
502 | retain_graph(bool, optional): If ``False``, the graph used to | |||
503 | compute the grad will be freed. Note that in nearly all | |||
504 | cases setting this option to ``True`` is not needed and | |||
505 | often can be worked around in a much more efficient way. | |||
506 | Usually, you need to set this to ``True`` to run backward | |||
507 | multiple times (default: False). | |||
508 | ||||
509 | Example:: | |||
510 | >>> import torch.distributed.autograd as dist_autograd | |||
511 | >>> with dist_autograd.context() as context_id: | |||
512 | >>> rref.backward(context_id) | |||
513 | )") | |||
514 | // not releasing GIL to avoid context switch | |||
515 | .def("__repr__", &PyRRef::str); | |||
516 | ||||
517 | shared_ptr_class_<ProcessGroupRpcBackendOptions>( | |||
518 | module, | |||
519 | "ProcessGroupRpcBackendOptions", | |||
520 | rpcBackendOptions, | |||
521 | R"( | |||
522 | The backend options class for ``ProcessGroupAgent``, which is derived | |||
523 | from ``RpcBackendOptions``. | |||
524 | ||||
525 | Args: | |||
526 | num_send_recv_threads (int, optional): The number of threads in | |||
527 | the thread-pool used by ``ProcessGroupAgent`` (default: 4). | |||
528 | rpc_timeout (float, optional): The default timeout, in seconds, | |||
529 | for RPC requests (default: 60 seconds). If the | |||
530 | RPC has not completed in this timeframe, an exception | |||
531 | indicating so will be raised. Callers can override this | |||
532 | timeout for individual RPCs in | |||
533 | :meth:`~torch.distributed.rpc.rpc_sync` and | |||
534 | :meth:`~torch.distributed.rpc.rpc_async` if necessary. | |||
535 | init_method (str, optional): The URL to initialize | |||
536 | ``ProcessGroupGloo`` (default: ``env://``). | |||
537 | )") | |||
538 | .def( | |||
539 | py::init<int, float, std::string>(), | |||
540 | py::arg("num_send_recv_threads") = kDefaultNumSendRecvThreads, | |||
541 | py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, | |||
542 | py::arg("init_method") = kDefaultInitMethod) | |||
543 | .def_readwrite( | |||
544 | "num_send_recv_threads", | |||
545 | &ProcessGroupRpcBackendOptions::numSendRecvThreads, | |||
546 | R"( | |||
547 | The number of threads in the thread-pool used by ProcessGroupAgent. | |||
548 | )"); | |||
549 | ||||
550 | module.attr("_DEFAULT_NUM_SEND_RECV_THREADS") = | |||
551 | py::cast(kDefaultNumSendRecvThreads); | |||
552 | ||||
553 | shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent) | |||
554 | .def(py::init([](const c10::intrusive_ptr<::c10d::Store>& store, | |||
555 | std::string workerName, | |||
556 | const c10::intrusive_ptr<::c10d::ProcessGroup>& pg, | |||
557 | int numSendRecvThreads, | |||
558 | std::chrono::milliseconds rpcTimeout) { | |||
559 | return std::shared_ptr<ProcessGroupAgent>( | |||
560 | new ProcessGroupAgent( | |||
561 | store, | |||
562 | std::move(workerName), | |||
563 | pg, | |||
564 | numSendRecvThreads, | |||
565 | rpcTimeout, | |||
566 | std::make_unique<RequestCallbackImpl>()), | |||
567 | impl::destroy_without_gil<ProcessGroupAgent>); | |||
568 | })) | |||
569 | .def( | |||
570 | "get_worker_info", | |||
571 | (const WorkerInfo& (ProcessGroupAgent::*)(void) const) & | |||
572 | RpcAgent::getWorkerInfo, | |||
573 | py::call_guard<py::gil_scoped_release>()) | |||
574 | .def( | |||
575 | "get_worker_info", | |||
576 | (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&) const) & | |||
577 | ProcessGroupAgent::getWorkerInfo, | |||
578 | py::call_guard<py::gil_scoped_release>()) | |||
579 | .def( | |||
580 | "get_worker_info", | |||
581 | (const WorkerInfo& (ProcessGroupAgent::*)(worker_id_t id) const) & | |||
582 | ProcessGroupAgent::getWorkerInfo, | |||
583 | py::call_guard<py::gil_scoped_release>()) | |||
584 | .def( | |||
585 | "get_worker_infos", | |||
586 | (std::vector<WorkerInfo>(ProcessGroupAgent::*)() const) & | |||
587 | ProcessGroupAgent::getWorkerInfos, | |||
588 | py::call_guard<py::gil_scoped_release>()) | |||
589 | .def( | |||
590 | "_get_device_map", | |||
591 | (DeviceMap(ProcessGroupAgent::*)(const WorkerInfo& dst) const) & | |||
592 | ProcessGroupAgent::getDeviceMap, | |||
593 | py::call_guard<py::gil_scoped_release>()) | |||
594 | .def( | |||
595 | "join", | |||
596 | &ProcessGroupAgent::join, | |||
597 | py::call_guard<py::gil_scoped_release>(), | |||
598 | py::arg("shutdown") = false) | |||
599 | .def( | |||
600 | "shutdown", | |||
601 | &ProcessGroupAgent::shutdown, | |||
602 | py::call_guard<py::gil_scoped_release>()) | |||
603 | .def( | |||
604 | "sync", | |||
605 | &ProcessGroupAgent::sync, | |||
606 | py::call_guard<py::gil_scoped_release>()); | |||
607 | ||||
608 | #ifdef USE_TENSORPIPE1 | |||
609 | ||||
610 | // Base class: torch.distributed.rpc.RpcBackendOptions. | |||
611 | py::class_<TensorPipeRpcBackendOptions>( | |||
612 | module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions) | |||
613 | .def( | |||
614 | py::init< | |||
615 | int, | |||
616 | optional<std::vector<std::string>>, | |||
617 | optional<std::vector<std::string>>, | |||
618 | float, | |||
619 | std::string, | |||
620 | std::unordered_map<std::string, DeviceMap>, | |||
621 | std::vector<c10::Device>>(), | |||
622 | py::arg("num_worker_threads") = kDefaultNumWorkerThreads, | |||
623 | py::arg("_transports") = optional<std::vector<std::string>>(), | |||
624 | py::arg("_channels") = optional<std::vector<std::string>>(), | |||
625 | py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, | |||
626 | py::arg("init_method") = kDefaultInitMethod, | |||
627 | py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(), | |||
628 | py::arg("devices") = std::vector<c10::Device>()) | |||
629 | .def_readwrite( | |||
630 | "num_worker_threads", | |||
631 | &TensorPipeRpcBackendOptions::numWorkerThreads, | |||
632 | R"( | |||
633 | The number of threads in the thread-pool used by | |||
634 | :class:`~torch.distributed.rpc.TensorPipeAgent` to execute | |||
635 | requests. | |||
636 | )") | |||
637 | .def_readwrite( | |||
638 | "device_maps", | |||
639 | &TensorPipeRpcBackendOptions::deviceMaps, | |||
640 | R"(The device map locations.)") | |||
641 | .def_readwrite( | |||
642 | "devices", | |||
643 | &TensorPipeRpcBackendOptions::devices, | |||
644 | R"(All devices used by the local agent.)") | |||
645 | .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap); | |||
646 | ||||
647 | module.attr("_DEFAULT_NUM_WORKER_THREADS") = | |||
648 | py::cast(kDefaultNumWorkerThreads); | |||
649 | ||||
650 | shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent) | |||
651 | .def( | |||
652 | py::init( | |||
653 | [](const c10::intrusive_ptr<::c10d::Store>& store, | |||
654 | std::string selfName, | |||
655 | worker_id_t selfId, | |||
656 | int worldSize, | |||
657 | c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, | |||
658 | TensorPipeRpcBackendOptions opts, | |||
659 | std::unordered_map<std::string, DeviceMap> reverseDeviceMaps, | |||
660 | std::vector<c10::Device> devices) { | |||
661 | return std::shared_ptr<TensorPipeAgent>( | |||
662 | new TensorPipeAgent( | |||
663 | store, | |||
664 | std::move(selfName), | |||
665 | selfId, | |||
666 | worldSize, | |||
667 | std::move(processGroup), | |||
668 | std::move(opts), | |||
669 | std::move(reverseDeviceMaps), | |||
670 | std::move(devices), | |||
671 | std::make_unique<RequestCallbackImpl>()), | |||
672 | impl::destroy_without_gil<TensorPipeAgent>); | |||
673 | }), | |||
674 | py::arg("store"), | |||
675 | py::arg("name"), | |||
676 | py::arg("rank"), | |||
677 | py::arg("world_size"), | |||
678 | py::arg("process_group"), | |||
679 | py::arg("rpc_backend_options"), | |||
680 | py::arg("reverse_device_maps"), | |||
681 | py::arg("devices")) | |||
682 | .def( | |||
683 | "join", | |||
684 | &TensorPipeAgent::join, | |||
685 | py::call_guard<py::gil_scoped_release>(), | |||
686 | py::arg("shutdown") = false) | |||
687 | .def( | |||
688 | "shutdown", | |||
689 | &TensorPipeAgent::shutdown, | |||
690 | py::call_guard<py::gil_scoped_release>()) | |||
691 | .def( | |||
692 | "get_worker_info", | |||
693 | (const WorkerInfo& (TensorPipeAgent::*)(void) const) & | |||
694 | RpcAgent::getWorkerInfo, | |||
695 | py::call_guard<py::gil_scoped_release>()) | |||
696 | .def( | |||
697 | "get_worker_info", | |||
698 | (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) & | |||
699 | TensorPipeAgent::getWorkerInfo, | |||
700 | py::call_guard<py::gil_scoped_release>()) | |||
701 | .def( | |||
702 | "get_worker_info", | |||
703 | (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) & | |||
704 | TensorPipeAgent::getWorkerInfo, | |||
705 | py::call_guard<py::gil_scoped_release>()) | |||
706 | .def( | |||
707 | "get_worker_infos", | |||
708 | (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) & | |||
709 | TensorPipeAgent::getWorkerInfos, | |||
710 | py::call_guard<py::gil_scoped_release>()) | |||
711 | .def( | |||
712 | "_get_device_map", | |||
713 | (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) & | |||
714 | TensorPipeAgent::getDeviceMap, | |||
715 | py::call_guard<py::gil_scoped_release>()); | |||
716 | ||||
717 | #endif // USE_TENSORPIPE | |||
718 | ||||
719 | module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet); | |||
720 | ||||
721 | module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent); | |||
722 | ||||
723 | module.def( | |||
724 | "_set_and_start_rpc_agent", | |||
725 | [](const std::shared_ptr<RpcAgent>& rpcAgent) { | |||
726 | RpcAgent::setCurrentRpcAgent(rpcAgent); | |||
727 | // Initializing typeResolver inside RpcAgent constructor will make | |||
728 | // RpcAgent have python dependency. To avoid RpcAgent to have python | |||
729 | // dependency, setTypeResolver() here. | |||
730 | std::shared_ptr<TypeResolver> typeResolver = | |||
731 | std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) { | |||
732 | auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr( | |||
733 | qn.qualifiedName()); | |||
734 | return c10::StrongTypePtr( | |||
735 | PythonRpcHandler::getInstance().jitCompilationUnit(), | |||
736 | std::move(typePtr)); | |||
737 | }); | |||
738 | rpcAgent->setTypeResolver(typeResolver); | |||
739 | rpcAgent->start(); | |||
740 | }, | |||
741 | py::call_guard<py::gil_scoped_release>()); | |||
742 | ||||
743 | module.def("_reset_current_rpc_agent", []() { | |||
744 | RpcAgent::setCurrentRpcAgent(nullptr); | |||
745 | }); | |||
746 | ||||
747 | module.def( | |||
748 | "_delete_all_user_and_unforked_owner_rrefs", | |||
749 | [](std::chrono::milliseconds timeoutMillis) { | |||
750 | RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis); | |||
751 | }, | |||
752 | py::arg("timeout") = kDeleteAllUsersTimeout, | |||
753 | py::call_guard<py::gil_scoped_release>()); | |||
754 | ||||
755 | module.def("_destroy_rref_context", [](bool ignoreRRefLeak) { | |||
756 | // NB: do not release GIL in the function. The destroyInstance() method | |||
757 | // returns a list of deleted OwnerRRefs that hold py::object instances. | |||
758 | // Clearing those OwnerRRefs are likely to trigger Python deref, which | |||
759 | // requires GIL. | |||
760 | RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear(); | |||
761 | }); | |||
762 | ||||
763 | module.def("_rref_context_get_debug_info", []() { | |||
764 | return RRefContext::getInstance().getDebugInfo(); | |||
765 | }); | |||
766 | ||||
767 | module.def( | |||
768 | "_cleanup_python_rpc_handler", | |||
769 | []() { PythonRpcHandler::getInstance().cleanup(); }, | |||
770 | py::call_guard<py::gil_scoped_release>()); | |||
771 | ||||
772 | module.def( | |||
773 | "_invoke_rpc_builtin", | |||
774 | [](const WorkerInfo& dst, | |||
775 | const std::string& opName, | |||
776 | const float rpcTimeoutSeconds, | |||
777 | const py::args& args, | |||
778 | const py::kwargs& kwargs) { | |||
779 | return std::make_shared<jit::PythonFutureWrapper>( | |||
780 | pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds)); | |||
781 | }, | |||
782 | py::call_guard<py::gil_scoped_acquire>()); | |||
783 | ||||
784 | module.def( | |||
785 | "_invoke_rpc_python_udf", | |||
786 | [](const WorkerInfo& dst, | |||
787 | std::string& pickledPythonUDF, | |||
788 | std::vector<torch::Tensor>& tensors, | |||
789 | const float rpcTimeoutSeconds, | |||
790 | const bool isAsyncExecution) { | |||
791 | return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf( | |||
792 | dst, | |||
793 | pickledPythonUDF, | |||
794 | tensors, | |||
795 | rpcTimeoutSeconds, | |||
796 | isAsyncExecution)); | |||
797 | }, | |||
798 | py::call_guard<py::gil_scoped_release>()); | |||
799 | ||||
800 | module.def( | |||
801 | "_invoke_rpc_torchscript", | |||
802 | [](const std::string& dstWorkerName, | |||
803 | const std::string& qualifiedNameStr, | |||
804 | const py::tuple& argsTuple, | |||
805 | const py::dict& kwargsDict, | |||
806 | const float rpcTimeoutSeconds, | |||
807 | const bool isAsyncExecution) { | |||
808 | return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript( | |||
809 | dstWorkerName, | |||
810 | qualifiedNameStr, | |||
811 | argsTuple, | |||
812 | kwargsDict, | |||
813 | rpcTimeoutSeconds, | |||
814 | isAsyncExecution)); | |||
815 | }, | |||
816 | py::call_guard<py::gil_scoped_release>()); | |||
817 | ||||
818 | module.def( | |||
819 | "_invoke_remote_builtin", | |||
820 | &pyRemoteBuiltin, | |||
821 | py::call_guard<py::gil_scoped_acquire>()); | |||
822 | ||||
823 | module.def( | |||
824 | "_invoke_remote_python_udf", | |||
825 | &pyRemotePythonUdf, | |||
826 | py::call_guard<py::gil_scoped_release>()); | |||
827 | ||||
828 | module.def( | |||
829 | "_invoke_remote_torchscript", | |||
830 | &pyRemoteTorchscript, | |||
831 | py::call_guard<py::gil_scoped_release>()); | |||
832 | ||||
833 | module.def( | |||
834 | "get_rpc_timeout", | |||
835 | []() { | |||
836 | return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() / | |||
837 | kSecToMsConversion; | |||
838 | }, | |||
839 | R"( | |||
840 | Retrieve the default timeout for all RPCs that was set during RPC initialization. | |||
841 | The returned value will be in seconds. | |||
842 | Returns: | |||
843 | ``float`` indicating the RPC timeout in seconds. | |||
844 | )"); | |||
845 | ||||
846 | module.def( | |||
847 | "enable_gil_profiling", | |||
848 | [](bool flag) { | |||
849 | RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag); | |||
850 | }, | |||
851 | R"( | |||
852 | Set whether GIL wait times should be enabled or not. This incurs a slight | |||
853 | overhead cost. Default is disabled for performance reasons. | |||
854 | ||||
855 | Args: | |||
856 | flag (bool): True to set GIL profiling, False to disable. | |||
857 | )"); | |||
858 | ||||
859 | module.def( | |||
860 | "_set_rpc_timeout", | |||
861 | [](const float rpcTimeoutSeconds) { | |||
862 | auto rpcTimeout = std::chrono::milliseconds( | |||
863 | static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion)); | |||
864 | RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout); | |||
865 | }, | |||
866 | R"( | |||
867 | Set the default timeout for all RPCs. The input unit is expected to be | |||
868 | in seconds. If an RPC is not completed within this time, an exception | |||
869 | indicating it has timed out will be raised. To control timeout for | |||
870 | specific RPCs, a timeout parameter can be passed into | |||
871 | :meth:`~torch.distributed.rpc.rpc_sync` and | |||
872 | :meth:`~torch.distributed.rpc.rpc_async`. | |||
873 | ||||
874 | Args: | |||
875 | rpcTimeoutSeconds (float): Timeout value in seconds. | |||
876 | )"); | |||
877 | ||||
878 | module.def( | |||
879 | "_enable_server_process_global_profiler", | |||
880 | &profiler::processglobal::enableServer); | |||
881 | module.def( | |||
882 | "_disable_server_process_global_profiler", | |||
883 | &profiler::processglobal::disableServer); | |||
884 | ||||
885 | module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId); | |||
886 | ||||
887 | py::class_< | |||
888 | RemoteProfilerManager, | |||
889 | std::unique_ptr<RemoteProfilerManager, py::nodelete>>( | |||
890 | module, "RemoteProfilerManager") | |||
891 | .def("set_current_profiling_key", [](const std::string& key) { | |||
892 | auto& inst = RemoteProfilerManager::getInstance(); | |||
893 | inst.setCurrentKey(key); | |||
894 | }); | |||
895 | ||||
896 | module.def( | |||
897 | "_enable_jit_rref_pickle", | |||
898 | &enableJitRRefPickle, | |||
899 | R"( | |||
900 | Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with | |||
901 | pickled RRefs out of RPC contexts. | |||
902 | ||||
903 | ||||
904 | .. warning:: | |||
905 | This is dangerous. If the module contains RRefs, the pickled | |||
906 | result must be sent over RPC and get unpickled on the receiving side | |||
907 | to restore the module. Otherwise, there will be RRef leaks, which | |||
908 | can potentially lead to program hang. When using this API, it is | |||
909 | applications responsibility to make sure that the above assumption | |||
910 | always holds. | |||
911 | )"); | |||
912 | module.def("_disable_jit_rref_pickle", &disableJitRRefPickle); | |||
913 | ||||
  Py_RETURN_TRUE;
915 | } | |||
916 | ||||
917 | } // namespace | |||
918 | ||||
919 | static PyMethodDef methods[] = { // NOLINT | |||
920 | {"_rpc_init", rpc_init, METH_NOARGS0x0004, nullptr}, | |||
921 | {nullptr, nullptr, 0, nullptr}}; | |||
922 | ||||
923 | PyMethodDef* python_functions() { | |||
924 | return methods; | |||
925 | } | |||
926 | ||||
927 | } // namespace rpc | |||
928 | } // namespace distributed | |||
929 | } // namespace torch |
// Clang static-analyzer model stub (appended to the translation unit
// by the analysis tooling; not part of the RPC module logic). When
// PyImport_ImportModule is a real function rather than a macro, this
// replacement definition tells the analyzer that the call returns a
// NEW PyObject reference — which is how the ownership-leak diagnostic
// in the header (line 36) is derived.
#ifndef PyImport_ImportModule
struct _object;
typedef struct _object PyObject;
// Opaque analyzer hook: conjures a fresh owned reference.
PyObject* clang_analyzer_PyObject_New_Reference();
PyObject* PyImport_ImportModule(const char *name) {
return clang_analyzer_PyObject_New_Reference();
}
#else
#warning "API PyImport_ImportModule is defined as a macro."
#endif