@@ -52,13 +52,15 @@ def __init__(self,
                  group: ProcessGroup,
                  device: Union[int, str, torch.device],
                  max_size=512 * 1024 * 1024,
-                 min_size=32 * 1024) -> None:
+                 min_size=128 * 1024) -> None:
         """
         Args:
             group: the process group to work on. If None, it will use the
                 default process group.
             device: the device to bind the QuickAllreduce to. If None,
                 it will be bind to f"cuda:{local_rank}".
+            max_size: the maximum supported input size.
+            min_size: below this size, custom_allreduce performs better.
         It is the caller's responsibility to make sure each communicator
         is bind to a unique device, and all communicators in this group
         are in the same node.
@@ -168,40 +170,20 @@ def should_quick_ar(self, inp: torch.Tensor):
             return inp_size < self.max_size  # and inp_size > self.min_size
         return False
 
-    def all_reduce(self,
-                   inp: torch.Tensor,
-                   *,
-                   out: torch.Tensor = None,
-                   registered: bool = False):
-        """Performs an out-of-place all reduce.
-
-        If registered is True, this assumes inp's pointer is already
-        IPC-registered. Otherwise, inp is first copied into a pre-registered
-        buffer.
-        """
+    def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
+        """Performs an out-of-place all reduce."""
         if out is None:
             out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
-        else:
-            # print("qr")
-            ops.qr_all_reduce(self._ptr, envs.VLLM_QUICK_ALLREDUCE, inp, out)
+        ops.qr_all_reduce(self._ptr, envs.VLLM_QUICK_ALLREDUCE, inp, out)
         return out
 
     def quick_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
         # When quick allreduce is disabled, this will be None.
         if self.disabled or not self.should_quick_ar(input):
             return None
-        if self._IS_CAPTURING:
-            if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
-            else:
-                # If warm up, mimic the allocation pattern since quick
-                # allreduce is out-of-place.
-                return torch.empty_like(input)
-        else:
-            return self.all_reduce(input, registered=False)
+
+        return self.all_reduce(input)
 
     def close(self):
         '''del self._ptr and del buffer'''
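
Since quick_all_reduce returns None whenever it is disabled or the size gate fails, callers need a fallback path. A minimal usage sketch, assuming qr_comm is an initialized QuickAllReduce instance and using a plain torch.distributed fallback (the wrapper name is hypothetical):

    import torch
    import torch.distributed as dist

    def reduce_with_fallback(qr_comm, tensor: torch.Tensor) -> torch.Tensor:
        out = qr_comm.quick_all_reduce(tensor)
        if out is not None:
            return out  # out-of-place result from quick allreduce
        # Fallback: dist.all_reduce reduces in place on the default group.
        dist.all_reduce(tensor)
        return tensor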