@@ -51,6 +51,7 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
+            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -59,6 +60,7 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
+            self.stream = None
             return

         self.available = True
@@ -96,12 +98,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()

-            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            stream.synchronize()
+            self.stream.synchronize()
             del data

     def all_reduce(self,
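Note on the hunk above: instead of issuing the warmup all_reduce on whatever stream happens to be current, the communicator now owns a dedicated `torch.cuda.Stream()` and synchronizes it explicitly before the warmup result is discarded. A minimal sketch of that pattern (illustrative only, not the communicator code itself; the names here are made up):

```python
import torch

# Work queued on a side stream runs asynchronously with respect to the
# current stream, so the owner must synchronize before reading results.
comm_stream = torch.cuda.Stream()          # communicator-owned stream
data = torch.zeros(1, device="cuda")

with torch.cuda.stream(comm_stream):       # enqueue work on the side stream
    data += 1                              # stand-in for ncclAllReduce

comm_stream.synchronize()                  # wait for the queued work
print(data.item())                         # safe to read now
```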
@@ -120,7 +122,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
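This hunk (and the matching ones below for all_gather, reduce_scatter, send, recv, and broadcast) changes the default: leaving `stream=None` now enqueues the collective on the communicator's own stream rather than on `torch.cuda.current_stream()`. A self-contained sketch of the "optional stream defaults to the instance stream" idiom, using a toy class rather than the real communicator:

```python
import torch

class Comm:  # toy stand-in, not the class in this diff
    def __init__(self) -> None:
        self.stream = torch.cuda.Stream()   # communicator-owned stream

    def all_reduce(self, t: torch.Tensor, stream=None) -> torch.Tensor:
        if stream is None:
            stream = self.stream            # new default: own stream
        with torch.cuda.stream(stream):     # enqueue on the chosen stream
            return t * 2                    # stand-in for ncclAllReduce
```

Callers who want the old behavior can still pass `stream=torch.cuda.current_stream()` explicitly.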
@@ -142,7 +144,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -163,7 +165,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -178,7 +180,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -190,7 +192,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -202,7 +204,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -215,17 +217,25 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
                             self.comm, cudaStream_t(stream.cuda_stream))

     @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
         """
         A context manager to change the state of the communicator.
         """
         if enable is None:
             # guess a default value when not specified
             enable = self.available

+        if stream is None:
+            stream = self.stream
+
         old_disable = self.disabled
+        old_stream = self.stream

+        self.stream = stream
         self.disabled = not enable
         yield

         self.disabled = old_disable
+        self.stream = old_stream
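The extended `change_state` context manager now saves and restores `self.stream` alongside `self.disabled`, so a caller can temporarily route all collectives onto a different stream for the duration of a block. A self-contained mimic of these semantics (simplified toy class, mirroring the hunk above rather than the real communicator):

```python
from contextlib import contextmanager
from typing import Optional

import torch

class Comm:  # toy stand-in for the communicator in this diff
    def __init__(self) -> None:
        self.available = True
        self.disabled = False
        self.stream = torch.cuda.Stream()

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        enable = self.available if enable is None else enable
        stream = self.stream if stream is None else stream
        old_disable, old_stream = self.disabled, self.stream
        self.disabled, self.stream = not enable, stream
        yield
        # restore both pieces of state on exit, exactly as the hunk does
        self.disabled, self.stream = old_disable, old_stream

# usage sketch: run collectives on a separate stream inside the block
comm = Comm()
with comm.change_state(enable=True, stream=torch.cuda.Stream()):
    assert not comm.disabled   # collectives enabled within the block
```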