@@ -51,7 +51,6 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
-            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -60,7 +59,6 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
-            self.stream = None
             return
 
         self.available = True
@@ -98,12 +96,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
-            self.stream = torch.cuda.Stream()
 
+            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            self.stream.synchronize()
+            stream.synchronize()
             del data
 
     def all_reduce(self,
@@ -122,7 +120,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -144,7 +142,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -165,7 +163,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -180,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -192,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -204,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -217,25 +215,17 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
                                 self.comm, cudaStream_t(stream.cuda_stream))
 
     @contextmanager
-    def change_state(self,
-                     enable: Optional[bool] = None,
-                     stream: Optional[torch.cuda.Stream] = None):
+    def change_state(self, enable: Optional[bool] = None):
         """
         A context manager to change the state of the communicator.
         """
         if enable is None:
             # guess a default value when not specified
             enable = self.available
 
-        if stream is None:
-            stream = self.stream
-
         old_disable = self.disabled
-        old_stream = self.stream
 
-        self.stream = stream
         self.disabled = not enable
         yield
 
         self.disabled = old_disable
-        self.stream = old_stream
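For reference, a minimal usage sketch of the behavior after this change (not part of the diff; pynccl_comm and device are assumed to be an existing PyNcclCommunicator instance and its CUDA device). Collectives are now enqueued on the caller's current CUDA stream, and change_state() no longer accepts a stream argument.

# Hypothetical usage sketch: `pynccl_comm` (a PyNcclCommunicator) and `device`
# are assumed to exist; they are not defined in this diff.
import torch

def allreduce_on_current_stream(pynccl_comm, device):
    # change_state() no longer takes a stream argument; NCCL calls are simply
    # issued on whatever CUDA stream is current for the caller.
    with pynccl_comm.change_state(enable=True):
        data = torch.ones(16, device=device)
        pynccl_comm.all_reduce(data)               # enqueued on torch.cuda.current_stream()
        torch.cuda.current_stream().synchronize()  # replaces the old pynccl_comm.stream.synchronize()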