File tree 6 files changed +45
-13
lines changed
6 files changed +45
-13
lines changed Original file line number Diff line number Diff line change 16
16
from torch .utils import cpp_extension
17
17
from common_utils import TEST_WITH_ROCM , shell
18
18
import torch .distributed as dist
19
+ PY33 = sys .version_info >= (3 , 3 )
19
20
PY36 = sys .version_info >= (3 , 6 )
20
21
21
22
TESTS = [
61
62
'function_schema' ,
62
63
]
63
64
64
- # skip < 3.6 b/c fstrings added in 3.6 for jit_py3
65
- # skip < 3.6 for rpc_spawn and dist_autograd_spawn temporarily because
66
- # a segmentation fault was triggered on python 3.5,
67
- # rpc_spawn and dist_autograd_spawn tests were added in
68
- # https://github.com/pytorch/pytorch/pull/25656
69
- # skip < 3.6 for rpc_fork as it imports mock that is only available in 3.6, mock
70
- # was added to rpc_fork in https://github.com/pytorch/pytorch/pull/26997
71
- if PY36 :
65
+ # skip < 3.3 because mock was added in 3.3 and is used in rpc_fork and rpc_spawn
66
+ # skip python2 for rpc and dist_autograd tests that do not support python2
67
+ if PY33 :
72
68
TESTS .extend ([
73
- 'jit_py3' ,
74
69
'rpc_fork' ,
75
70
'rpc_spawn' ,
76
71
'dist_autograd_fork' ,
77
72
'dist_autograd_spawn' ,
78
73
])
79
74
75
+ # skip < 3.6 because jit_py3 uses f-strings, which were added in 3.6
76
+ if PY36 :
77
+ TESTS .extend ([
78
+ 'jit_py3' ,
79
+ ])
80
+
80
81
WINDOWS_BLACKLIST = [
81
82
'distributed' ,
82
83
'rpc_fork' ,
Original file line number Diff line number Diff line change @@ -71,6 +71,9 @@ PyObject* rpc_init(PyObject* /* unused */) {
71
71
return PyRRef::unpickle (t);
72
72
}));
73
73
74
+ // future.wait() should not be called after join_rpc(): pythonRpcHandler
75
+ // is cleaned up in join_rpc(), so after join_rpc() the Python objects
76
+ // returned from an RPC Python call can no longer be resolved.
74
77
auto futureMessage =
75
78
shared_ptr_class_<FutureMessage>(module, " FutureMessage" )
76
79
.def (
@@ -112,6 +115,10 @@ PyObject* rpc_init(PyObject* /* unused */) {
112
115
RRefContext::getInstance ().destroyInstance ();
113
116
});
114
117
118
+ module.def (" _cleanup_python_rpc_handler" , []() {
119
+ PythonRpcHandler::getInstance ().cleanup ();
120
+ });
121
+
115
122
module.def (
116
123
" invoke_rpc_builtin" ,
117
124
[](RpcAgent& agent,
Original file line number Diff line number Diff line change @@ -147,9 +147,6 @@ ProcessGroupAgent::ProcessGroupAgent(
147
147
for (int rank = 0 ; rank < (int )tmpWorkerIds.size (); ++rank) {
148
148
allWorkerInfo_.emplace_back (std::move (tmpWorkerIds[rank]), rank);
149
149
}
150
-
151
- // construct PythonRpcHandler singleton here
152
- PythonRpcHandler::getInstance ();
153
150
}
154
151
155
152
const WorkerInfo& ProcessGroupAgent::getWorkerInfo (
Original file line number Diff line number Diff line change @@ -26,6 +26,13 @@ PythonRpcHandler::PythonRpcHandler() {
26
26
pySerialize_ = getFunction (module, " serialize" );
27
27
}
28
28
29
+ void PythonRpcHandler::cleanup () {
30
+ AutoGIL ag;
31
+ pyRunFunction_ = py::none ();
32
+ pyLoadReturnValue_ = py::none ();
33
+ pySerialize_ = py::none ();
34
+ }
35
+
29
36
PythonRpcHandler& PythonRpcHandler::getInstance () {
30
37
static PythonRpcHandler handler;
31
38
return handler;
Original file line number Diff line number Diff line change @@ -39,6 +39,21 @@ class PYBIND11_EXPORT PythonRpcHandler {
39
39
// Deserialize a string into a py::object
40
40
py::object deserialize (const SerializedPyObj& serializedObj);
41
41
42
+ // Explicitly clean up py::objects to avoid segmentation faults when
43
+ // py::objects with CPython are cleaned up later at program exit.
44
+ // See similar issues reported at https://github.com/pybind/pybind11/issues/1598
45
+ // and https://github.com/pybind/pybind11/issues/1493.
46
+ // Our local tests also hit this segmentation fault when py::objects are cleaned
47
+ // up at program exit. The explanation is: CPython cleans up most critical
48
+ // utilities before cleaning up the PythonRpcHandler singleton, so when the
49
+ // PythonRpcHandler singleton cleans up py::objects and calls dec_ref(), it
50
+ // will crash.
51
+ // The solution is to clean up py::objects earlier, when the Rpc agent is joined.
52
+ // Note that py::objects cannot be cleaned up when the Rpc agent is destroyed
53
+ // either, as the Rpc agent is a global variable and would have the same issue as
54
+ // PythonRpcHandler.
55
+ void cleanup ();
56
+
42
57
private:
43
58
PythonRpcHandler ();
44
59
~PythonRpcHandler () = default ;
Original file line number Diff line number Diff line change 1
1
from torch .distributed import invoke_rpc_builtin , invoke_rpc_python_udf
2
2
from torch .distributed import invoke_remote_builtin , invoke_remote_python_udf
3
3
from torch .distributed import _start_rpc_agent
4
- from torch .distributed import _destroy_rref_context
4
+ from torch .distributed import _destroy_rref_context , _cleanup_python_rpc_handler
5
5
from torch .distributed import ProcessGroupAgent
6
6
from torch .distributed import WorkerInfo
7
7
from .backend_registry import is_backend_registered , init_backend
@@ -40,6 +40,11 @@ def join_rpc():
40
40
_agent .join ()
41
41
_agent = None
42
42
_destroy_rref_context ()
43
+ # Clean up the Python RPC handler in join_rpc(); see the comments on
44
+ # PythonRpcHandler::cleanup(). It is called from the Python API because
45
+ # cleanup() depends on Python: it assumes the Python
46
+ # interpreter still exists.
47
+ _cleanup_python_rpc_handler ()
43
48
44
49
45
50
@_require_initialized
You can’t perform that action at this time.
0 commit comments