| File: | build/../torch/csrc/distributed/rpc/init.cpp |
| Warning: | line 36, column 20 PyObject ownership leak with reference count of 1 |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
| 1 | #include <torch/csrc/python_headers.h> | |||
| 2 | ||||
| 3 | #include <torch/csrc/distributed/rpc/process_group_agent.h> | |||
| 4 | #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> | |||
| 5 | #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h> | |||
| 6 | #include <torch/csrc/distributed/rpc/py_rref.h> | |||
| 7 | #include <torch/csrc/distributed/rpc/python_functions.h> | |||
| 8 | #include <torch/csrc/distributed/rpc/python_rpc_handler.h> | |||
| 9 | #include <torch/csrc/distributed/rpc/request_callback_impl.h> | |||
| 10 | #include <torch/csrc/distributed/rpc/rpc_agent.h> | |||
| 11 | #include <torch/csrc/distributed/rpc/rref_context.h> | |||
| 12 | #include <torch/csrc/distributed/rpc/tensorpipe_agent.h> | |||
| 13 | #include <torch/csrc/distributed/rpc/torchscript_functions.h> | |||
| 14 | #include <torch/csrc/distributed/rpc/types.h> | |||
| 15 | #include <torch/csrc/jit/python/pybind_utils.h> | |||
| 16 | #include <torch/csrc/utils/object_ptr.h> | |||
| 17 | #include <torch/csrc/utils/pybind.h> | |||
| 18 | #include <torch/types.h> | |||
| 19 | ||||
| 20 | #include <pybind11/chrono.h> | |||
| 21 | #include <pybind11/operators.h> | |||
| 22 | ||||
| 23 | namespace torch { | |||
| 24 | namespace distributed { | |||
| 25 | namespace rpc { | |||
| 26 | ||||
| 27 | namespace { | |||
| 28 | ||||
| 29 | constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000); | |||
| 30 | ||||
| 31 | template <typename T> | |||
| 32 | using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>; | |||
| 33 | ||||
| 34 | PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { | |||
| 35 | auto rpc_module = | |||
| 36 | THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc")); | |||
| ||||
| ||||
| 37 | if (!rpc_module) { | |||
| 38 | throw python_error(); | |||
| 39 | } | |||
| 40 | ||||
| 41 | auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); | |||
| 42 | if (!torch_C_module) { | |||
| 43 | throw python_error(); | |||
| 44 | } | |||
| 45 | ||||
| 46 | auto torch_C_m = py::handle(torch_C_module).cast<py::module>(); | |||
| 47 | auto m = | |||
| 48 | torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings"); | |||
| 49 | ||||
| 50 | auto module = py::handle(m).cast<py::module>(); | |||
| 51 | ||||
| 52 | auto rpcBackendOptions = | |||
| 53 | shared_ptr_class_<RpcBackendOptions>( | |||
| 54 | module, | |||
| 55 | "RpcBackendOptions", | |||
| 56 | R"(An abstract structure encapsulating the options passed into the RPC | |||
| 57 | backend. An instance of this class can be passed in to | |||
| 58 | :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC | |||
| 59 | with specific configurations, such as the RPC timeout and | |||
| 60 | ``init_method`` to be used. )") | |||
| 61 | .def(py::init<>()) | |||
| 62 | .def( | |||
| 63 | py::init<float, std::string>(), | |||
| 64 | py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, | |||
| 65 | py::arg("init_method") = kDefaultInitMethod) | |||
| 66 | .def_readwrite( | |||
| 67 | "rpc_timeout", | |||
| 68 | &RpcBackendOptions::rpcTimeoutSeconds, | |||
| 69 | R"(A float indicating the timeout to use for all | |||
| 70 | RPCs. If an RPC does not complete in this timeframe, it will | |||
| 71 | complete with an exception indicating that it has timed out.)") | |||
| 72 | .def_readwrite( | |||
| 73 | "init_method", | |||
| 74 | &RpcBackendOptions::initMethod, | |||
| 75 | R"(URL specifying how to initialize the process group. | |||
| 76 | Default is ``env://``)"); | |||
| 77 | ||||
| 78 | // The following C++ constants need to be cast so they can be used from | |||
| 79 | // python. | |||
| 80 | module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds); | |||
| 81 | module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout); | |||
| 82 | module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod); | |||
| 83 | ||||
| 84 | auto workerInfo = | |||
| 85 | shared_ptr_class_<WorkerInfo>( | |||
| 86 | module, | |||
| 87 | "WorkerInfo", | |||
| 88 | R"(A structure that encapsulates information of a worker in the system. | |||
| 89 | Contains the name and ID of the worker. This class is not meant to | |||
| 90 | be constructed directly, rather, an instance can be retrieved | |||
| 91 | through :meth:`~torch.distributed.rpc.get_worker_info` and the | |||
| 92 | result can be passed in to functions such as | |||
| 93 | :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`, | |||
| 94 | :meth:`~torch.distributed.rpc.remote` to avoid copying a string on | |||
| 95 | every invocation.)") | |||
| 96 | .def( | |||
| 97 | py::init<std::string, worker_id_t>(), | |||
| 98 | py::arg("name"), | |||
| 99 | py::arg("id")) | |||
| 100 | .def_readonly( | |||
| 101 | "name", &WorkerInfo::name_, R"(The name of the worker.)") | |||
| 102 | .def_readonly( | |||
| 103 | "id", | |||
| 104 | &WorkerInfo::id_, | |||
| 105 | R"(Globally unique id to identify the worker.)") | |||
| 106 | .def("__eq__", &WorkerInfo::operator==, py::is_operator()) | |||
| 107 | // pybind11 suggests the syntax .def(hash(py::self)), with the | |||
| 108 | // unqualified "hash" function call. However the | |||
| 109 | // argument-dependent lookup for the function "hash" doesn't get | |||
| 110 | // triggered in this context because it conflicts with the struct | |||
| 111 | // c10::hash, so we need to use the qualified name | |||
| 112 | // py::detail::hash, which unfortunately is in a detail namespace. | |||
| 113 | .def(py::detail::hash(py::self)) // NOLINT | |||
| 114 | .def("__repr__", [](const WorkerInfo& workerInfo) { | |||
| 115 | std::ostringstream os; | |||
| 116 | os << workerInfo; | |||
| 117 | return os.str(); | |||
| 118 | }); | |||
| 119 | ||||
| 120 | auto rpcAgent = | |||
| 121 | shared_ptr_class_<RpcAgent>(module, "RpcAgent") | |||
| 122 | .def( | |||
| 123 | "join", | |||
| 124 | &RpcAgent::join, | |||
| 125 | py::call_guard<py::gil_scoped_release>(), | |||
| 126 | py::arg("shutdown") = false) | |||
| 127 | .def( | |||
| 128 | "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>()) | |||
| 129 | .def( | |||
| 130 | "shutdown", | |||
| 131 | &RpcAgent::shutdown, | |||
| 132 | py::call_guard<py::gil_scoped_release>()) | |||
| 133 | .def( | |||
| 134 | "get_worker_info", | |||
| 135 | (const WorkerInfo& (RpcAgent::*)(void) const) & | |||
| 136 | RpcAgent::getWorkerInfo, | |||
| 137 | py::call_guard<py::gil_scoped_release>()) | |||
| 138 | .def( | |||
| 139 | "get_worker_info", | |||
| 140 | (const WorkerInfo& (RpcAgent::*)(const std::string&) const) & | |||
| 141 | RpcAgent::getWorkerInfo, | |||
| 142 | py::call_guard<py::gil_scoped_release>()) | |||
| 143 | .def( | |||
| 144 | "get_worker_infos", | |||
| 145 | &RpcAgent::getWorkerInfos, | |||
| 146 | py::call_guard<py::gil_scoped_release>()) | |||
| 147 | .def( | |||
| 148 | "_get_device_map", | |||
| 149 | &RpcAgent::getDeviceMap, | |||
| 150 | py::call_guard<py::gil_scoped_release>()) | |||
| 151 | .def( | |||
| 152 | "get_debug_info", | |||
| 153 | &RpcAgent::getDebugInfo, | |||
| 154 | py::call_guard<py::gil_scoped_release>()) | |||
| 155 | .def( | |||
| 156 | "get_metrics", | |||
| 157 | &RpcAgent::getMetrics, | |||
| 158 | py::call_guard<py::gil_scoped_release>()); | |||
| 159 | ||||
| 160 | auto pyRRef = | |||
| 161 | shared_ptr_class_<PyRRef>(module, "PyRRef", R"( | |||
| 162 | A class encapsulating a reference to a value of some type on a remote | |||
| 163 | worker. This handle will keep the referenced remote value alive on the | |||
| 164 | worker. A ``UserRRef`` will be deleted when 1) no references to it in | |||
| 165 | both the application code and in the local RRef context, or 2) the | |||
| 166 | application has called a graceful shutdown. Invoking methods on a | |||
| 167 | deleted RRef leads to undefined behaviors. RRef implementation only | |||
| 168 | offers best-effort error detection, and applications should not use | |||
| 169 | ``UserRRefs`` after ``rpc.shutdown()``. | |||
| 170 | ||||
| 171 | .. warning:: | |||
| 172 | RRefs can only be serialized and deserialized by the RPC module. | |||
| 173 | Serializing and deserializing RRefs without RPC (e.g., Python | |||
| 174 | pickle, torch :meth:`~torch.save` / :meth:`~torch.load`, | |||
| 175 | JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will | |||
| 176 | lead to errors. | |||
| 177 | ||||
| 178 | Args: | |||
| 179 | value (object): The value to be wrapped by this RRef. | |||
| 180 | type_hint (Type, optional): Python type that should be passed to | |||
| 181 | ``TorchScript`` compiler as type hint for ``value``. | |||
| 182 | ||||
| 183 | Example:: | |||
| 184 | Following examples skip RPC initialization and shutdown code | |||
| 185 | for simplicity. Refer to RPC docs for those details. | |||
| 186 | ||||
| 187 | 1. Create an RRef using rpc.remote | |||
| 188 | ||||
| 189 | >>> import torch | |||
| 190 | >>> import torch.distributed.rpc as rpc | |||
| 191 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) | |||
| 192 | >>> # get a copy of value from the RRef | |||
| 193 | >>> x = rref.to_here() | |||
| 194 | ||||
| 195 | 2. Create an RRef from a local object | |||
| 196 | ||||
| 197 | >>> import torch | |||
| 198 | >>> from torch.distributed.rpc import RRef | |||
| 199 | >>> x = torch.zeros(2, 2) | |||
| 200 | >>> rref = RRef(x) | |||
| 201 | ||||
| 202 | 3. Share an RRef with other workers | |||
| 203 | ||||
| 204 | >>> # On both worker0 and worker1: | |||
| 205 | >>> def f(rref): | |||
| 206 | >>> return rref.to_here() + 1 | |||
| 207 | ||||
| 208 | >>> # On worker0: | |||
| 209 | >>> import torch | |||
| 210 | >>> import torch.distributed.rpc as rpc | |||
| 211 | >>> from torch.distributed.rpc import RRef | |||
| 212 | >>> rref = RRef(torch.zeros(2, 2)) | |||
| 213 | >>> # the following RPC shares the rref with worker1, reference | |||
| 214 | >>> # count is automatically updated. | |||
| 215 | >>> rpc.rpc_sync("worker1", f, args=(rref,)) | |||
| 216 | )") | |||
| 217 | .def( | |||
| 218 | py::init<const py::object&, const py::object&>(), | |||
| 219 | py::arg("value"), | |||
| 220 | py::arg("type_hint") = py::none()) | |||
| 221 | .def( | |||
| 222 | // not releasing GIL here to avoid context switch on getters | |||
| 223 | "is_owner", | |||
| 224 | &PyRRef::isOwner, | |||
| 225 | R"( | |||
| 226 | Returns whether or not the current node is the owner of this | |||
| 227 | ``RRef``. | |||
| 228 | )") | |||
| 229 | .def( | |||
| 230 | "confirmed_by_owner", | |||
| 231 | &PyRRef::confirmedByOwner, | |||
| 232 | R"( | |||
| 233 | Returns whether this ``RRef`` has been confirmed by the owner. | |||
| 234 | ``OwnerRRef`` always returns true, while ``UserRRef`` only | |||
| 235 | returns true when the owner knowns about this ``UserRRef``. | |||
| 236 | )") | |||
| 237 | .def( | |||
| 238 | // not releasing GIL here to avoid context switch on getters | |||
| 239 | "owner", | |||
| 240 | &PyRRef::owner, | |||
| 241 | R"( | |||
| 242 | Returns worker information of the node that owns this ``RRef``. | |||
| 243 | )") | |||
| 244 | .def( | |||
| 245 | // not releasing GIL here to avoid context switch on getters | |||
| 246 | "owner_name", | |||
| 247 | &PyRRef::ownerName, | |||
| 248 | R"( | |||
| 249 | Returns worker name of the node that owns this ``RRef``. | |||
| 250 | )") | |||
| 251 | .def( | |||
| 252 | "to_here", | |||
| 253 | &PyRRef::toHere, | |||
| 254 | py::arg("timeout") = py::cast(kUnsetRpcTimeout), | |||
| 255 | py::call_guard<py::gil_scoped_release>(), | |||
| 256 | R"( | |||
| 257 | Blocking call that copies the value of the RRef from the owner | |||
| 258 | to the local node and returns it. If the current node is the | |||
| 259 | owner, returns a reference to the local value. | |||
| 260 | ||||
| 261 | Args: | |||
| 262 | timeout (float, optional): Timeout for ``to_here``. If | |||
| 263 | the call does not complete within this timeframe, an | |||
| 264 | exception indicating so will be raised. If this | |||
| 265 | argument is not provided, the default RPC timeout | |||
| 266 | (60s) will be used. | |||
| 267 | )") | |||
| 268 | .def( | |||
| 269 | "local_value", | |||
| 270 | &PyRRef::localValue, | |||
| 271 | py::call_guard<py::gil_scoped_release>(), | |||
| 272 | R"( | |||
| 273 | If the current node is the owner, returns a reference to the | |||
| 274 | local value. Otherwise, throws an exception. | |||
| 275 | )") | |||
| 276 | .def( | |||
| 277 | "rpc_sync", | |||
| 278 | [](const PyRRef& self, float timeoutSeconds) { | |||
| 279 | return self.createRRefProxy( | |||
| 280 | RRefProxyType::RPC_SYNC, timeoutSeconds); | |||
| 281 | }, | |||
| 282 | py::arg("timeout") = kUnsetRpcTimeout, | |||
| 283 | py::call_guard<py::gil_scoped_release>(), | |||
| 284 | R"( | |||
| 285 | Create a helper proxy to easily launch an ``rpc_sync`` using | |||
| 286 | the owner of the RRef as the destination to run functions on | |||
| 287 | the object referenced by this RRef. More specifically, | |||
| 288 | ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as | |||
| 289 | the following: | |||
| 290 | ||||
| 291 | >>> def run(rref, func_name, args, kwargs): | |||
| 292 | >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) | |||
| 293 | >>> | |||
| 294 | >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs)) | |||
| 295 | ||||
| 296 | Args: | |||
| 297 | timeout (float, optional): Timeout for ``rref.rpc_sync()``. | |||
| 298 | If the call does not complete within this timeframe, an | |||
| 299 | exception indicating so will be raised. If this argument | |||
| 300 | is not provided, the default RPC timeout will be used. | |||
| 301 | ||||
| 302 | Example:: | |||
| 303 | >>> from torch.distributed import rpc | |||
| 304 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) | |||
| 305 | >>> rref.rpc_sync().size() # returns torch.Size([2, 2]) | |||
| 306 | >>> rref.rpc_sync().view(1, 4) # returns tensor([[1., 1., 1., 1.]]) | |||
| 307 | )") | |||
| 308 | .def( | |||
| 309 | "rpc_async", | |||
| 310 | [](const PyRRef& self, float timeoutSeconds) { | |||
| 311 | return self.createRRefProxy( | |||
| 312 | RRefProxyType::RPC_ASYNC, timeoutSeconds); | |||
| 313 | }, | |||
| 314 | py::arg("timeout") = kUnsetRpcTimeout, | |||
| 315 | py::call_guard<py::gil_scoped_release>(), | |||
| 316 | R"( | |||
| 317 | Create a helper proxy to easily launch an ``rpc_async`` using | |||
| 318 | the owner of the RRef as the destination to run functions on | |||
| 319 | the object referenced by this RRef. More specifically, | |||
| 320 | ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as | |||
| 321 | the following: | |||
| 322 | ||||
| 323 | >>> def run(rref, func_name, args, kwargs): | |||
| 324 | >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) | |||
| 325 | >>> | |||
| 326 | >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs)) | |||
| 327 | ||||
| 328 | Args: | |||
| 329 | timeout (float, optional): Timeout for ``rref.rpc_async()``. | |||
| 330 | If the call does not complete within this timeframe, an | |||
| 331 | exception indicating so will be raised. If this argument | |||
| 332 | is not provided, the default RPC timeout will be used. | |||
| 333 | ||||
| 334 | Example:: | |||
| 335 | >>> from torch.distributed import rpc | |||
| 336 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) | |||
| 337 | >>> rref.rpc_async().size().wait() # returns torch.Size([2, 2]) | |||
| 338 | >>> rref.rpc_async().view(1, 4).wait() # returns tensor([[1., 1., 1., 1.]]) | |||
| 339 | )") | |||
| 340 | .def( | |||
| 341 | "remote", | |||
| 342 | [](const PyRRef& self, float timeoutSeconds) { | |||
| 343 | return self.createRRefProxy( | |||
| 344 | RRefProxyType::REMOTE, timeoutSeconds); | |||
| 345 | }, | |||
| 346 | py::arg("timeout") = kUnsetRpcTimeout, | |||
| 347 | py::call_guard<py::gil_scoped_release>(), | |||
| 348 | R"( | |||
| 349 | Create a helper proxy to easily launch a ``remote`` using | |||
| 350 | the owner of the RRef as the destination to run functions on | |||
| 351 | the object referenced by this RRef. More specifically, | |||
| 352 | ``rref.remote().func_name(*args, **kwargs)`` is the same as | |||
| 353 | the following: | |||
| 354 | ||||
| 355 | >>> def run(rref, func_name, args, kwargs): | |||
| 356 | >>> return getattr(rref.local_value(), func_name)(*args, **kwargs) | |||
| 357 | >>> | |||
| 358 | >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs)) | |||
| 359 | ||||
| 360 | Args: | |||
| 361 | timeout (float, optional): Timeout for ``rref.remote()``. If | |||
| 362 | the creation of this :class:`~torch.distributed.rpc.RRef` | |||
| 363 | is not successfully completed within the timeout, then the | |||
| 364 | next time there is an attempt to use the RRef | |||
| 365 | (such as ``to_here``), a timeout will be raised. If not | |||
| 366 | provided, the default RPC timeout will be used. Please see | |||
| 367 | ``rpc.remote()`` for specific timeout semantics for | |||
| 368 | :class:`~torch.distributed.rpc.RRef`. | |||
| 369 | ||||
| 370 | Example:: | |||
| 371 | >>> from torch.distributed import rpc | |||
| 372 | >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1)) | |||
| 373 | >>> rref.remote().size().to_here() # returns torch.Size([2, 2]) | |||
| 374 | >>> rref.remote().view(1, 4).to_here() # returns tensor([[1., 1., 1., 1.]]) | |||
| 375 | )") | |||
| 376 | .def( | |||
| 377 | py::pickle( | |||
| 378 | /* __getstate__ */ | |||
| 379 | [](const PyRRef& /* unused */) { | |||
| 380 | TORCH_CHECK( | |||
| 381 | false, | |||
| 382 | "Can not pickle rref in python pickler, rref can only be " | |||
| 383 | "pickled when using RPC"); | |||
| 384 | // Note that this return has no meaning since we always | |||
| 385 | // throw, it's only here to satisfy Pybind API's | |||
| 386 | // requirement. | |||
| 387 | return py::make_tuple(); | |||
| 388 | }, | |||
| 389 | /* __setstate__ */ | |||
| 390 | [](py::tuple /* unused */) { // NOLINT | |||
| 391 | TORCH_CHECK( | |||
| 392 | false, | |||
| 393 | "Can not unpickle rref in python pickler, rref can only be " | |||
| 394 | "unpickled when using RPC"); | |||
| 395 | // Note that this return has no meaning since we always | |||
| 396 | // throw, it's only here to satisfy PyBind's API | |||
| 397 | // requirement. | |||
| 398 | return PyRRef( | |||
| 399 | py::cast<py::none>(Py_None), | |||
| 400 | py::cast<py::none>(Py_None)); | |||
| 401 | }), | |||
| 402 | py::call_guard<py::gil_scoped_release>()) | |||
| 403 | .def( | |||
| 404 | "_serialize", | |||
| 405 | &PyRRef::pickle, | |||
| 406 | py::call_guard<py::gil_scoped_release>()) | |||
| 407 | .def_static( | |||
| 408 | "_deserialize", | |||
| 409 | &PyRRef::unpickle, | |||
| 410 | py::call_guard<py::gil_scoped_release>()) | |||
| 411 | .def( | |||
| 412 | "_get_type", | |||
| 413 | // Intentionally not releasing GIL, as most accesses just | |||
| 414 | // retrieve cached type py::object | |||
| 415 | &PyRRef::getRRefType, | |||
| 416 | py::arg("timeout") = kUnsetRpcTimeout, | |||
| 417 | py::arg("blocking") = true, | |||
| 418 | R"( | |||
| 419 | If ``blocking=True``, returns the type of the data object | |||
| 420 | referenced by this ``RRef``. On the owner, this is same as | |||
| 421 | ``type(rref.local_value())``. Otherwise, returns a future to | |||
| 422 | this result. On a user, this will trigger an RPC to fetch the | |||
| 423 | ``type`` object from the owner. After this function is run | |||
| 424 | once, the ``type`` object is cached by the ``RRef``, and | |||
| 425 | subsequent invocations no longer trigger RPC. Note that this is | |||
| 426 | true regardless of the ``blocking`` argument of subsequent | |||
| 427 | calls. | |||
| 428 | ||||
| 429 | Args: | |||
| 430 | rref (torch.distributed.rpc.RRef): The RRef to get type of. | |||
| 431 | timeout (float, optional): Timeout, in seconds for | |||
| 432 | ``_get_type``. If the call does not complete within | |||
| 433 | this timeframe, an exception indicating so will be | |||
| 434 | raised. If this argument is not provided, the default | |||
| 435 | RPC timeout will be used. | |||
| 436 | blocking (bool, optional): Whether to synchronously wait on | |||
| 437 | the RPC triggered by the first call and return the | |||
| 438 | type. If ``False``, will return a future. Default is | |||
| 439 | ``True``. | |||
| 440 | )") | |||
| 441 | .def( | |||
| 442 | "_get_future", | |||
| 443 | [](const PyRRef& self) { | |||
| 444 | return std::make_shared<jit::PythonFutureWrapper>( | |||
| 445 | self.getFuture()); | |||
| 446 | }, | |||
| 447 | py::call_guard<py::gil_scoped_release>(), | |||
| 448 | R"( | |||
| 449 | Returns the future that corresponds to the creation of this RRef | |||
| 450 | on the remote node. This is for internal use cases such as profiling | |||
| 451 | only. | |||
| 452 | )") | |||
| 453 | .def( | |||
| 454 | "_get_profiling_future", | |||
| 455 | [](const PyRRef& self) { | |||
| 456 | return std::make_shared<jit::PythonFutureWrapper>( | |||
| 457 | self.getProfilingFuture()); | |||
| 458 | }, | |||
| 459 | py::call_guard<py::gil_scoped_acquire>(), | |||
| 460 | R"( | |||
| 461 | Returns future that completes when the profiling event corresponding | |||
| 462 | to the creation of this RRef on the remote node has been recorded. | |||
| 463 | )") | |||
| 464 | .def( | |||
| 465 | "_set_profiling_future", | |||
| 466 | [](PyRRef& self, | |||
| 467 | const std::shared_ptr<jit::PythonFutureWrapper>& | |||
| 468 | wrappedFuture) { | |||
| 469 | self.setProfilingFuture(wrappedFuture->fut); | |||
| 470 | }, | |||
| 471 | py::call_guard<py::gil_scoped_acquire>(), | |||
| 472 | R"( | |||
| 473 | Set future that is completed when the profiling event corresponding | |||
| 474 | to the creation of this RRef on the remote node has been recorded. | |||
| 475 | )") | |||
| 476 | .def( | |||
| 477 | "backward", | |||
| 478 | [](PyRRef& self, | |||
| 479 | int64_t dist_autograd_ctx_id, | |||
| 480 | bool retain_graph) { | |||
| 481 | self.backward(dist_autograd_ctx_id, retain_graph); | |||
| 482 | }, | |||
| 483 | py::arg("dist_autograd_ctx_id") = -1, | |||
| 484 | py::arg("retain_graph") = false, | |||
| 485 | py::call_guard<py::gil_scoped_release>(), | |||
| 486 | R"( | |||
| 487 | Runs the backward pass using the RRef as the root of the | |||
| 488 | backward pass. If ``dist_autograd_ctx_id`` is provided, | |||
| 489 | we perform a distributed backward pass using the provided | |||
| 490 | ctx_id starting from the owner of the RRef. In this case, | |||
| 491 | :meth:`~torch.distributed.autograd.get_gradients` should be | |||
| 492 | used to retrieve the gradients. If ``dist_autograd_ctx_id`` | |||
| 493 | is ``None``, it is assumed that this is a local autograd graph | |||
| 494 | and we only perform a local backward pass. In the local case, | |||
| 495 | the node calling this API has to be the owner of the RRef. | |||
| 496 | The value of the RRef is expected to be a scalar Tensor. | |||
| 497 | ||||
| 498 | Args: | |||
| 499 | dist_autograd_ctx_id (int, optional): The distributed | |||
| 500 | autograd context id for which we should retrieve the | |||
| 501 | gradients (default: -1). | |||
| 502 | retain_graph(bool, optional): If ``False``, the graph used to | |||
| 503 | compute the grad will be freed. Note that in nearly all | |||
| 504 | cases setting this option to ``True`` is not needed and | |||
| 505 | often can be worked around in a much more efficient way. | |||
| 506 | Usually, you need to set this to ``True`` to run backward | |||
| 507 | multiple times (default: False). | |||
| 508 | ||||
| 509 | Example:: | |||
| 510 | >>> import torch.distributed.autograd as dist_autograd | |||
| 511 | >>> with dist_autograd.context() as context_id: | |||
| 512 | >>> rref.backward(context_id) | |||
| 513 | )") | |||
| 514 | // not releasing GIL to avoid context switch | |||
| 515 | .def("__repr__", &PyRRef::str); | |||
| 516 | ||||
| 517 | shared_ptr_class_<ProcessGroupRpcBackendOptions>( | |||
| 518 | module, | |||
| 519 | "ProcessGroupRpcBackendOptions", | |||
| 520 | rpcBackendOptions, | |||
| 521 | R"( | |||
| 522 | The backend options class for ``ProcessGroupAgent``, which is derived | |||
| 523 | from ``RpcBackendOptions``. | |||
| 524 | ||||
| 525 | Args: | |||
| 526 | num_send_recv_threads (int, optional): The number of threads in | |||
| 527 | the thread-pool used by ``ProcessGroupAgent`` (default: 4). | |||
| 528 | rpc_timeout (float, optional): The default timeout, in seconds, | |||
| 529 | for RPC requests (default: 60 seconds). If the | |||
| 530 | RPC has not completed in this timeframe, an exception | |||
| 531 | indicating so will be raised. Callers can override this | |||
| 532 | timeout for individual RPCs in | |||
| 533 | :meth:`~torch.distributed.rpc.rpc_sync` and | |||
| 534 | :meth:`~torch.distributed.rpc.rpc_async` if necessary. | |||
| 535 | init_method (str, optional): The URL to initialize | |||
| 536 | ``ProcessGroupGloo`` (default: ``env://``). | |||
| 537 | )") | |||
| 538 | .def( | |||
| 539 | py::init<int, float, std::string>(), | |||
| 540 | py::arg("num_send_recv_threads") = kDefaultNumSendRecvThreads, | |||
| 541 | py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, | |||
| 542 | py::arg("init_method") = kDefaultInitMethod) | |||
| 543 | .def_readwrite( | |||
| 544 | "num_send_recv_threads", | |||
| 545 | &ProcessGroupRpcBackendOptions::numSendRecvThreads, | |||
| 546 | R"( | |||
| 547 | The number of threads in the thread-pool used by ProcessGroupAgent. | |||
| 548 | )"); | |||
| 549 | ||||
| 550 | module.attr("_DEFAULT_NUM_SEND_RECV_THREADS") = | |||
| 551 | py::cast(kDefaultNumSendRecvThreads); | |||
| 552 | ||||
| 553 | shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent) | |||
| 554 | .def(py::init([](const c10::intrusive_ptr<::c10d::Store>& store, | |||
| 555 | std::string workerName, | |||
| 556 | const c10::intrusive_ptr<::c10d::ProcessGroup>& pg, | |||
| 557 | int numSendRecvThreads, | |||
| 558 | std::chrono::milliseconds rpcTimeout) { | |||
| 559 | return std::shared_ptr<ProcessGroupAgent>( | |||
| 560 | new ProcessGroupAgent( | |||
| 561 | store, | |||
| 562 | std::move(workerName), | |||
| 563 | pg, | |||
| 564 | numSendRecvThreads, | |||
| 565 | rpcTimeout, | |||
| 566 | std::make_unique<RequestCallbackImpl>()), | |||
| 567 | impl::destroy_without_gil<ProcessGroupAgent>); | |||
| 568 | })) | |||
| 569 | .def( | |||
| 570 | "get_worker_info", | |||
| 571 | (const WorkerInfo& (ProcessGroupAgent::*)(void) const) & | |||
| 572 | RpcAgent::getWorkerInfo, | |||
| 573 | py::call_guard<py::gil_scoped_release>()) | |||
| 574 | .def( | |||
| 575 | "get_worker_info", | |||
| 576 | (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&) const) & | |||
| 577 | ProcessGroupAgent::getWorkerInfo, | |||
| 578 | py::call_guard<py::gil_scoped_release>()) | |||
| 579 | .def( | |||
| 580 | "get_worker_info", | |||
| 581 | (const WorkerInfo& (ProcessGroupAgent::*)(worker_id_t id) const) & | |||
| 582 | ProcessGroupAgent::getWorkerInfo, | |||
| 583 | py::call_guard<py::gil_scoped_release>()) | |||
| 584 | .def( | |||
| 585 | "get_worker_infos", | |||
| 586 | (std::vector<WorkerInfo>(ProcessGroupAgent::*)() const) & | |||
| 587 | ProcessGroupAgent::getWorkerInfos, | |||
| 588 | py::call_guard<py::gil_scoped_release>()) | |||
| 589 | .def( | |||
| 590 | "_get_device_map", | |||
| 591 | (DeviceMap(ProcessGroupAgent::*)(const WorkerInfo& dst) const) & | |||
| 592 | ProcessGroupAgent::getDeviceMap, | |||
| 593 | py::call_guard<py::gil_scoped_release>()) | |||
| 594 | .def( | |||
| 595 | "join", | |||
| 596 | &ProcessGroupAgent::join, | |||
| 597 | py::call_guard<py::gil_scoped_release>(), | |||
| 598 | py::arg("shutdown") = false) | |||
| 599 | .def( | |||
| 600 | "shutdown", | |||
| 601 | &ProcessGroupAgent::shutdown, | |||
| 602 | py::call_guard<py::gil_scoped_release>()) | |||
| 603 | .def( | |||
| 604 | "sync", | |||
| 605 | &ProcessGroupAgent::sync, | |||
| 606 | py::call_guard<py::gil_scoped_release>()); | |||
| 607 | ||||
| 608 | #ifdef USE_TENSORPIPE1 | |||
| 609 | ||||
| 610 | // Base class: torch.distributed.rpc.RpcBackendOptions. | |||
| 611 | py::class_<TensorPipeRpcBackendOptions>( | |||
| 612 | module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions) | |||
| 613 | .def( | |||
| 614 | py::init< | |||
| 615 | int, | |||
| 616 | optional<std::vector<std::string>>, | |||
| 617 | optional<std::vector<std::string>>, | |||
| 618 | float, | |||
| 619 | std::string, | |||
| 620 | std::unordered_map<std::string, DeviceMap>, | |||
| 621 | std::vector<c10::Device>>(), | |||
| 622 | py::arg("num_worker_threads") = kDefaultNumWorkerThreads, | |||
| 623 | py::arg("_transports") = optional<std::vector<std::string>>(), | |||
| 624 | py::arg("_channels") = optional<std::vector<std::string>>(), | |||
| 625 | py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds, | |||
| 626 | py::arg("init_method") = kDefaultInitMethod, | |||
| 627 | py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(), | |||
| 628 | py::arg("devices") = std::vector<c10::Device>()) | |||
| 629 | .def_readwrite( | |||
| 630 | "num_worker_threads", | |||
| 631 | &TensorPipeRpcBackendOptions::numWorkerThreads, | |||
| 632 | R"( | |||
| 633 | The number of threads in the thread-pool used by | |||
| 634 | :class:`~torch.distributed.rpc.TensorPipeAgent` to execute | |||
| 635 | requests. | |||
| 636 | )") | |||
| 637 | .def_readwrite( | |||
| 638 | "device_maps", | |||
| 639 | &TensorPipeRpcBackendOptions::deviceMaps, | |||
| 640 | R"(The device map locations.)") | |||
| 641 | .def_readwrite( | |||
| 642 | "devices", | |||
| 643 | &TensorPipeRpcBackendOptions::devices, | |||
| 644 | R"(All devices used by the local agent.)") | |||
| 645 | .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap); | |||
| 646 | ||||
| 647 | module.attr("_DEFAULT_NUM_WORKER_THREADS") = | |||
| 648 | py::cast(kDefaultNumWorkerThreads); | |||
| 649 | ||||
| 650 | shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent) | |||
| 651 | .def( | |||
| 652 | py::init( | |||
| 653 | [](const c10::intrusive_ptr<::c10d::Store>& store, | |||
| 654 | std::string selfName, | |||
| 655 | worker_id_t selfId, | |||
| 656 | int worldSize, | |||
| 657 | c10::intrusive_ptr<::c10d::ProcessGroup> processGroup, | |||
| 658 | TensorPipeRpcBackendOptions opts, | |||
| 659 | std::unordered_map<std::string, DeviceMap> reverseDeviceMaps, | |||
| 660 | std::vector<c10::Device> devices) { | |||
| 661 | return std::shared_ptr<TensorPipeAgent>( | |||
| 662 | new TensorPipeAgent( | |||
| 663 | store, | |||
| 664 | std::move(selfName), | |||
| 665 | selfId, | |||
| 666 | worldSize, | |||
| 667 | std::move(processGroup), | |||
| 668 | std::move(opts), | |||
| 669 | std::move(reverseDeviceMaps), | |||
| 670 | std::move(devices), | |||
| 671 | std::make_unique<RequestCallbackImpl>()), | |||
| 672 | impl::destroy_without_gil<TensorPipeAgent>); | |||
| 673 | }), | |||
| 674 | py::arg("store"), | |||
| 675 | py::arg("name"), | |||
| 676 | py::arg("rank"), | |||
| 677 | py::arg("world_size"), | |||
| 678 | py::arg("process_group"), | |||
| 679 | py::arg("rpc_backend_options"), | |||
| 680 | py::arg("reverse_device_maps"), | |||
| 681 | py::arg("devices")) | |||
| 682 | .def( | |||
| 683 | "join", | |||
| 684 | &TensorPipeAgent::join, | |||
| 685 | py::call_guard<py::gil_scoped_release>(), | |||
| 686 | py::arg("shutdown") = false) | |||
| 687 | .def( | |||
| 688 | "shutdown", | |||
| 689 | &TensorPipeAgent::shutdown, | |||
| 690 | py::call_guard<py::gil_scoped_release>()) | |||
| 691 | .def( | |||
| 692 | "get_worker_info", | |||
| 693 | (const WorkerInfo& (TensorPipeAgent::*)(void) const) & | |||
| 694 | RpcAgent::getWorkerInfo, | |||
| 695 | py::call_guard<py::gil_scoped_release>()) | |||
| 696 | .def( | |||
| 697 | "get_worker_info", | |||
| 698 | (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) & | |||
| 699 | TensorPipeAgent::getWorkerInfo, | |||
| 700 | py::call_guard<py::gil_scoped_release>()) | |||
| 701 | .def( | |||
| 702 | "get_worker_info", | |||
| 703 | (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) & | |||
| 704 | TensorPipeAgent::getWorkerInfo, | |||
| 705 | py::call_guard<py::gil_scoped_release>()) | |||
| 706 | .def( | |||
| 707 | "get_worker_infos", | |||
| 708 | (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) & | |||
| 709 | TensorPipeAgent::getWorkerInfos, | |||
| 710 | py::call_guard<py::gil_scoped_release>()) | |||
| 711 | .def( | |||
| 712 | "_get_device_map", | |||
| 713 | (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) & | |||
| 714 | TensorPipeAgent::getDeviceMap, | |||
| 715 | py::call_guard<py::gil_scoped_release>()); | |||
| 716 | ||||
| 717 | #endif // USE_TENSORPIPE | |||
| 718 | ||||
| 719 | module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet); | |||
| 720 | ||||
| 721 | module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent); | |||
| 722 | ||||
| 723 | module.def( | |||
| 724 | "_set_and_start_rpc_agent", | |||
| 725 | [](const std::shared_ptr<RpcAgent>& rpcAgent) { | |||
| 726 | RpcAgent::setCurrentRpcAgent(rpcAgent); | |||
| 727 | // Initializing typeResolver inside RpcAgent constructor will make | |||
| 728 | // RpcAgent have python dependency. To avoid RpcAgent to have python | |||
| 729 | // dependency, setTypeResolver() here. | |||
| 730 | std::shared_ptr<TypeResolver> typeResolver = | |||
| 731 | std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) { | |||
| 732 | auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr( | |||
| 733 | qn.qualifiedName()); | |||
| 734 | return c10::StrongTypePtr( | |||
| 735 | PythonRpcHandler::getInstance().jitCompilationUnit(), | |||
| 736 | std::move(typePtr)); | |||
| 737 | }); | |||
| 738 | rpcAgent->setTypeResolver(typeResolver); | |||
| 739 | rpcAgent->start(); | |||
| 740 | }, | |||
| 741 | py::call_guard<py::gil_scoped_release>()); | |||
| 742 | ||||
| 743 | module.def("_reset_current_rpc_agent", []() { | |||
| 744 | RpcAgent::setCurrentRpcAgent(nullptr); | |||
| 745 | }); | |||
| 746 | ||||
| 747 | module.def( | |||
| 748 | "_delete_all_user_and_unforked_owner_rrefs", | |||
| 749 | [](std::chrono::milliseconds timeoutMillis) { | |||
| 750 | RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis); | |||
| 751 | }, | |||
| 752 | py::arg("timeout") = kDeleteAllUsersTimeout, | |||
| 753 | py::call_guard<py::gil_scoped_release>()); | |||
| 754 | ||||
| 755 | module.def("_destroy_rref_context", [](bool ignoreRRefLeak) { | |||
| 756 | // NB: do not release GIL in the function. The destroyInstance() method | |||
| 757 | // returns a list of deleted OwnerRRefs that hold py::object instances. | |||
| 758 | // Clearing those OwnerRRefs are likely to trigger Python deref, which | |||
| 759 | // requires GIL. | |||
| 760 | RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear(); | |||
| 761 | }); | |||
| 762 | ||||
| 763 | module.def("_rref_context_get_debug_info", []() { | |||
| 764 | return RRefContext::getInstance().getDebugInfo(); | |||
| 765 | }); | |||
| 766 | ||||
| 767 | module.def( | |||
| 768 | "_cleanup_python_rpc_handler", | |||
| 769 | []() { PythonRpcHandler::getInstance().cleanup(); }, | |||
| 770 | py::call_guard<py::gil_scoped_release>()); | |||
| 771 | ||||
| 772 | module.def( | |||
| 773 | "_invoke_rpc_builtin", | |||
| 774 | [](const WorkerInfo& dst, | |||
| 775 | const std::string& opName, | |||
| 776 | const float rpcTimeoutSeconds, | |||
| 777 | const py::args& args, | |||
| 778 | const py::kwargs& kwargs) { | |||
| 779 | return std::make_shared<jit::PythonFutureWrapper>( | |||
| 780 | pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds)); | |||
| 781 | }, | |||
| 782 | py::call_guard<py::gil_scoped_acquire>()); | |||
| 783 | ||||
| 784 | module.def( | |||
| 785 | "_invoke_rpc_python_udf", | |||
| 786 | [](const WorkerInfo& dst, | |||
| 787 | std::string& pickledPythonUDF, | |||
| 788 | std::vector<torch::Tensor>& tensors, | |||
| 789 | const float rpcTimeoutSeconds, | |||
| 790 | const bool isAsyncExecution) { | |||
| 791 | return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf( | |||
| 792 | dst, | |||
| 793 | pickledPythonUDF, | |||
| 794 | tensors, | |||
| 795 | rpcTimeoutSeconds, | |||
| 796 | isAsyncExecution)); | |||
| 797 | }, | |||
| 798 | py::call_guard<py::gil_scoped_release>()); | |||
| 799 | ||||
| 800 | module.def( | |||
| 801 | "_invoke_rpc_torchscript", | |||
| 802 | [](const std::string& dstWorkerName, | |||
| 803 | const std::string& qualifiedNameStr, | |||
| 804 | const py::tuple& argsTuple, | |||
| 805 | const py::dict& kwargsDict, | |||
| 806 | const float rpcTimeoutSeconds, | |||
| 807 | const bool isAsyncExecution) { | |||
| 808 | return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript( | |||
| 809 | dstWorkerName, | |||
| 810 | qualifiedNameStr, | |||
| 811 | argsTuple, | |||
| 812 | kwargsDict, | |||
| 813 | rpcTimeoutSeconds, | |||
| 814 | isAsyncExecution)); | |||
| 815 | }, | |||
| 816 | py::call_guard<py::gil_scoped_release>()); | |||
| 817 | ||||
| 818 | module.def( | |||
| 819 | "_invoke_remote_builtin", | |||
| 820 | &pyRemoteBuiltin, | |||
| 821 | py::call_guard<py::gil_scoped_acquire>()); | |||
| 822 | ||||
| 823 | module.def( | |||
| 824 | "_invoke_remote_python_udf", | |||
| 825 | &pyRemotePythonUdf, | |||
| 826 | py::call_guard<py::gil_scoped_release>()); | |||
| 827 | ||||
| 828 | module.def( | |||
| 829 | "_invoke_remote_torchscript", | |||
| 830 | &pyRemoteTorchscript, | |||
| 831 | py::call_guard<py::gil_scoped_release>()); | |||
| 832 | ||||
| 833 | module.def( | |||
| 834 | "get_rpc_timeout", | |||
| 835 | []() { | |||
| 836 | return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() / | |||
| 837 | kSecToMsConversion; | |||
| 838 | }, | |||
| 839 | R"( | |||
| 840 | Retrieve the default timeout for all RPCs that was set during RPC initialization. | |||
| 841 | The returned value will be in seconds. | |||
| 842 | Returns: | |||
| 843 | ``float`` indicating the RPC timeout in seconds. | |||
| 844 | )"); | |||
| 845 | ||||
| 846 | module.def( | |||
| 847 | "enable_gil_profiling", | |||
| 848 | [](bool flag) { | |||
| 849 | RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag); | |||
| 850 | }, | |||
| 851 | R"( | |||
| 852 | Set whether GIL wait times should be enabled or not. This incurs a slight | |||
| 853 | overhead cost. Default is disabled for performance reasons. | |||
| 854 | ||||
| 855 | Args: | |||
| 856 | flag (bool): True to set GIL profiling, False to disable. | |||
| 857 | )"); | |||
| 858 | ||||
| 859 | module.def( | |||
| 860 | "_set_rpc_timeout", | |||
| 861 | [](const float rpcTimeoutSeconds) { | |||
| 862 | auto rpcTimeout = std::chrono::milliseconds( | |||
| 863 | static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion)); | |||
| 864 | RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout); | |||
| 865 | }, | |||
| 866 | R"( | |||
| 867 | Set the default timeout for all RPCs. The input unit is expected to be | |||
| 868 | in seconds. If an RPC is not completed within this time, an exception | |||
| 869 | indicating it has timed out will be raised. To control timeout for | |||
| 870 | specific RPCs, a timeout parameter can be passed into | |||
| 871 | :meth:`~torch.distributed.rpc.rpc_sync` and | |||
| 872 | :meth:`~torch.distributed.rpc.rpc_async`. | |||
| 873 | ||||
| 874 | Args: | |||
| 875 | rpcTimeoutSeconds (float): Timeout value in seconds. | |||
| 876 | )"); | |||
| 877 | ||||
| 878 | module.def( | |||
| 879 | "_enable_server_process_global_profiler", | |||
| 880 | &profiler::processglobal::enableServer); | |||
| 881 | module.def( | |||
| 882 | "_disable_server_process_global_profiler", | |||
| 883 | &profiler::processglobal::disableServer); | |||
| 884 | ||||
| 885 | module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId); | |||
| 886 | ||||
| 887 | py::class_< | |||
| 888 | RemoteProfilerManager, | |||
| 889 | std::unique_ptr<RemoteProfilerManager, py::nodelete>>( | |||
| 890 | module, "RemoteProfilerManager") | |||
| 891 | .def("set_current_profiling_key", [](const std::string& key) { | |||
| 892 | auto& inst = RemoteProfilerManager::getInstance(); | |||
| 893 | inst.setCurrentKey(key); | |||
| 894 | }); | |||
| 895 | ||||
| 896 | module.def( | |||
| 897 | "_enable_jit_rref_pickle", | |||
| 898 | &enableJitRRefPickle, | |||
| 899 | R"( | |||
| 900 | Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with | |||
| 901 | pickled RRefs out of RPC contexts. | |||
| 902 | ||||
| 903 | ||||
| 904 | .. warning:: | |||
| 905 | This is dangerous. If the module contains RRefs, the pickled | |||
| 906 | result must be sent over RPC and get unpickled on the receiving side | |||
| 907 | to restore the module. Otherwise, there will be RRef leaks, which | |||
| 908 | can potentially lead to program hang. When using this API, it is | |||
| 909 | applications responsibility to make sure that the above assumption | |||
| 910 | always holds. | |||
| 911 | )"); | |||
| 912 | module.def("_disable_jit_rref_pickle", &disableJitRRefPickle); | |||
| 913 | ||||
| 914 | Py_RETURN_TRUEreturn _Py_INCREF(((PyObject*)(((PyObject *) &_Py_TrueStruct )))), ((PyObject *) &_Py_TrueStruct); | |||
| 915 | } | |||
| 916 | ||||
| 917 | } // namespace | |||
| 918 | ||||
| 919 | static PyMethodDef methods[] = { // NOLINT | |||
| 920 | {"_rpc_init", rpc_init, METH_NOARGS0x0004, nullptr}, | |||
| 921 | {nullptr, nullptr, 0, nullptr}}; | |||
| 922 | ||||
| 923 | PyMethodDef* python_functions() { | |||
| 924 | return methods; | |||
| 925 | } | |||
| 926 | ||||
| 927 | } // namespace rpc | |||
| 928 | } // namespace distributed | |||
| 929 | } // namespace torch |
// Clang Static Analyzer modeling stub appended to the report — NOT part of
// the real build. It teaches the analyzer that PyImport_ImportModule()
// returns a *new* PyObject reference (refcount +1) owned by the caller,
// which is how the "ownership leak with reference count of 1" diagnostic
// on line 36 is derived. Guarded so it only applies when the API is a true
// function; if it is a macro, the stub cannot shadow it.
#ifndef PyImport_ImportModule
struct _object;
typedef struct _object PyObject;
// Analyzer intrinsic: produces an object the caller owns and must release.
PyObject* clang_analyzer_PyObject_New_Reference();
PyObject* PyImport_ImportModule(const char* name) {
  return clang_analyzer_PyObject_New_Reference();
}
#else
#warning "API PyImport_ImportModule is defined as a macro."
#endif