Skip to content

Commit edab7a4

Browse files
zhaojuanmaoThiago Crepaldi
authored and
Thiago Crepaldi
committed
fix python rpc handler exit crash (pytorch#27251)
Summary: Pull Request resolved: pytorch#27251 Explicitly clean up py::objects to avoid segment faults when py::objects with CPython are cleaned up later at program exit. See similar issues reported pybind/pybind11#1598 and pybind/pybind11#1493. Our local tests also caught this segment faults if py::objects are cleaned up at program exit. The explaination is: CPython cleans up most critical utitlies before cleaning up PythonRpcHandler singleton, so when PythonRpcHandler signleton cleans up py::objects and call dec_ref(), it will crash. The solution is to clean up py::objects earlier when Rpc agent join(). Be note that py::objects can not be cleaned up when Rpc agent is destroyed as well, as Rpc agent is global variable and it will have same issue as PythonRpcHandler. close pytorch#27182 ghstack-source-id: 92035069 Test Plan: unit tests on python 3.6 and python 3.5 Differential Revision: D17727362 fbshipit-source-id: c254023f6a85acce35528ba756a4efabba9a519f
1 parent 68871d1 commit edab7a4

File tree

6 files changed

+45
-13
lines changed

6 files changed

+45
-13
lines changed

test/run_test.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.utils import cpp_extension
1717
from common_utils import TEST_WITH_ROCM, shell
1818
import torch.distributed as dist
19+
PY33 = sys.version_info >= (3, 3)
1920
PY36 = sys.version_info >= (3, 6)
2021

2122
TESTS = [
@@ -61,22 +62,22 @@
6162
'function_schema',
6263
]
6364

64-
# skip < 3.6 b/c fstrings added in 3.6 for jit_py3
65-
# skip < 3.6 for rpc_spawn and dist_autograd_spawn temporarily because
66-
# a segmenation fault was triggered on python 3.5,
67-
# rpc_spawn and dist_autograd_spawn tests were added in
68-
# https://github.com/pytorch/pytorch/pull/25656
69-
# skip < 3.6 for rpc_fork as it imports mock that is only available in 3.6, mock
70-
# was added to rpc_fork in https://github.com/pytorch/pytorch/pull/26997
71-
if PY36:
65+
# skip < 3.3 because mock is added in 3.3 and is used in rpc_fork and rpc_spawn
66+
# skip python2 for rpc and dist_autograd tests that do not support python2
67+
if PY33:
7268
TESTS.extend([
73-
'jit_py3',
7469
'rpc_fork',
7570
'rpc_spawn',
7671
'dist_autograd_fork',
7772
'dist_autograd_spawn',
7873
])
7974

75+
# skip < 3.6 b/c fstrings added in 3.6
76+
if PY36:
77+
TESTS.extend([
78+
'jit_py3',
79+
])
80+
8081
WINDOWS_BLACKLIST = [
8182
'distributed',
8283
'rpc_fork',

torch/csrc/distributed/rpc/init.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ PyObject* rpc_init(PyObject* /* unused */) {
7171
return PyRRef::unpickle(t);
7272
}));
7373

74+
// future.wait() should not be called after join_rpc(), e.g., pythonRpcHandler
75+
// is cleaned up in join_rpc(), after join_rpc(), python objects returned
76+
// from rpc python call can not be resolved.
7477
auto futureMessage =
7578
shared_ptr_class_<FutureMessage>(module, "FutureMessage")
7679
.def(
@@ -112,6 +115,10 @@ PyObject* rpc_init(PyObject* /* unused */) {
112115
RRefContext::getInstance().destroyInstance();
113116
});
114117

118+
module.def("_cleanup_python_rpc_handler", []() {
119+
PythonRpcHandler::getInstance().cleanup();
120+
});
121+
115122
module.def(
116123
"invoke_rpc_builtin",
117124
[](RpcAgent& agent,

torch/csrc/distributed/rpc/process_group_agent.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,6 @@ ProcessGroupAgent::ProcessGroupAgent(
147147
for (int rank = 0; rank < (int)tmpWorkerIds.size(); ++rank) {
148148
allWorkerInfo_.emplace_back(std::move(tmpWorkerIds[rank]), rank);
149149
}
150-
151-
// construct PythonRpcHandler singleton here
152-
PythonRpcHandler::getInstance();
153150
}
154151

155152
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(

torch/csrc/distributed/rpc/python_rpc_handler.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ PythonRpcHandler::PythonRpcHandler() {
2626
pySerialize_ = getFunction(module, "serialize");
2727
}
2828

29+
void PythonRpcHandler::cleanup() {
30+
AutoGIL ag;
31+
pyRunFunction_ = py::none();
32+
pyLoadReturnValue_ = py::none();
33+
pySerialize_ = py::none();
34+
}
35+
2936
PythonRpcHandler& PythonRpcHandler::getInstance() {
3037
static PythonRpcHandler handler;
3138
return handler;

torch/csrc/distributed/rpc/python_rpc_handler.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ class PYBIND11_EXPORT PythonRpcHandler {
3939
// Deserialize a string into a py::object
4040
py::object deserialize(const SerializedPyObj& serializedObj);
4141

42+
// Explicitly clean up py::objects to avoid segment faults when
43+
// py::objects with CPython are cleaned up later at program exit
44+
// See similar issues reported https://github.com/pybind/pybind11/issues/1598
45+
// and https://github.com/pybind/pybind11/issues/1493
46+
// Our local tests also caught this segment faults if py::objects are cleaned
47+
// up at program exit. The explaination is: CPython cleans up most critical
48+
// utilities before cleaning up PythonRpcHandler singleton, so when
49+
// PythonRpcHandler signleton cleans up py::objects and call dec_ref(), it
50+
// will crash.
51+
// The solution is to clean up py::objects earlier when Rpc agent join().
52+
// Be note that py::objects can not be cleaned up when Rpc agent is destroyed
53+
// as well, as Rpc agent is global variable and it will have same issue as
54+
// PythonRpcHandler.
55+
void cleanup();
56+
4257
private:
4358
PythonRpcHandler();
4459
~PythonRpcHandler() = default;

torch/distributed/rpc/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from torch.distributed import invoke_rpc_builtin, invoke_rpc_python_udf
22
from torch.distributed import invoke_remote_builtin, invoke_remote_python_udf
33
from torch.distributed import _start_rpc_agent
4-
from torch.distributed import _destroy_rref_context
4+
from torch.distributed import _destroy_rref_context, _cleanup_python_rpc_handler
55
from torch.distributed import ProcessGroupAgent
66
from torch.distributed import WorkerInfo
77
from .backend_registry import is_backend_registered, init_backend
@@ -40,6 +40,11 @@ def join_rpc():
4040
_agent.join()
4141
_agent = None
4242
_destroy_rref_context()
43+
# clean up python rpc handler in join_rpc(), see comments in
44+
# PythonRpcHandler::cleanup(), call it in python API because the
45+
# cleanup() function has python dependency, it assumes python
46+
# interpreter exists
47+
_cleanup_python_rpc_handler()
4348

4449

4550
@_require_initialized

0 commit comments

Comments
 (0)