@@ -96,42 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None:
    _groups[group.unique_name] = weakref.ref(group)


-if supports_custom_op():
-
-    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce_in_place(tensor)
-
-    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
-        return
+def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce_out_place(tensor)

-    direct_register_custom_op(
-        op_name="inplace_all_reduce",
-        op_func=inplace_all_reduce,
-        mutates_args=["tensor"],
-        fake_impl=inplace_all_reduce_fake,
-    )

-    def outplace_all_reduce(tensor: torch.Tensor,
-                            group_name: str) -> torch.Tensor:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce_out_place(tensor)
+def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)

-    def outplace_all_reduce_fake(tensor: torch.Tensor,
-                                 group_name: str) -> torch.Tensor:
-        return torch.empty_like(tensor)

+if supports_custom_op():
    direct_register_custom_op(
-        op_name="outplace_all_reduce",
-        op_func=outplace_all_reduce,
+        op_name="all_reduce",
+        op_func=all_reduce,
        mutates_args=[],
-        fake_impl=outplace_all_reduce_fake,
+        fake_impl=all_reduce_fake,
    )


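The hunk above collapses the separate in-place and out-of-place ops into a single out-of-place `all_reduce` custom op whose fake impl only reports output metadata. For reference, a minimal stand-alone sketch of the same pattern using only the public `torch.library` API (this is not vLLM's `direct_register_custom_op` helper; the `demo` namespace and `_DEMO_GROUPS` registry are invented for the example, and it assumes PyTorch >= 2.4):

```python
import torch

_DEMO_GROUPS: dict[str, object] = {}  # toy stand-in for the `_groups` registry


@torch.library.custom_op("demo::all_reduce", mutates_args=())
def demo_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # A real op would dispatch to the named communicator; this single-process
    # sketch returns a copy so the out-of-place contract stays visible.
    assert group_name in _DEMO_GROUPS, f"Group {group_name} is not found."
    return tensor.clone()


@demo_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # The fake impl is also out-of-place: a fresh tensor with the same metadata.
    return torch.empty_like(tensor)


_DEMO_GROUPS["tp"] = object()
out = torch.ops.demo.all_reduce(torch.ones(4), group_name="tp")
```

Because nothing is listed in `mutates_args`, tracing goes through the fake impl and the compiled graph simply sees a brand-new tensor, which is what `all_reduce_fake` above encodes with `torch.empty_like`.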
@@ -317,30 +299,13 @@ def graph_capture(
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
-            # In graph mode, we have to be very careful about the collective
-            # operations. The current status is:
-            #     allreduce \ Mode   |  Eager  |  Graph  |
-            # --------------------------------------------
-            # custom allreduce       | enabled | enabled |
-            # PyNccl                 | disabled| enabled |
-            # torch.distributed      | enabled | disabled|
-            #
-            # Note that custom allreduce will have a runtime check, if the
-            # tensor size is too large, it will fallback to the next
-            # available option.
-            # In summary: When using CUDA graph, we use
-            #  either custom all-reduce kernel or pynccl. When not using
-            #  CUDA graph, we use either custom all-reduce kernel or
-            #  PyTorch NCCL. We always prioritize using custom all-reduce
-            #  kernel but fall back to PyTorch or pynccl if it is
-            #  disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(
-                    enable=True, stream=torch.cuda.current_stream())
+                    stream=torch.cuda.current_stream())
            with maybe_pynccl_context:
                yield graph_capture_context

@@ -356,8 +321,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
-        a new tensor in the same op. So we need to figure out if the op is
-        in-place or out-of-place ahead of time.
+        a new tensor in the same op. So we always make the all-reduce operation
+        out-of-place.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
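Because a custom op cannot both mutate an argument and return a new tensor, the unified op follows a strict out-of-place contract: the input is never written to and callers must use the returned tensor. A tiny illustrative check (the `toy_all_reduce` helper is hypothetical, standing in for the real collective):

```python
import torch


def toy_all_reduce(x: torch.Tensor) -> torch.Tensor:
    out = x.clone()  # reduce a copy, never the caller's tensor
    # ... a real collective would sum `out` across ranks here ...
    return out


x = torch.ones(4)
y = toy_all_reduce(x)
assert y is not x                      # a fresh tensor comes back
assert torch.equal(x, torch.ones(4))   # the input is left untouched
```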
@@ -368,10 +333,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
            ipex.distributed.all_reduce(input_, group=self.device_group)
            return input_

-        if not supports_custom_op():
-            self._all_reduce_in_place(input_)
-            return input_
-
        if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
@@ -385,30 +346,31 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
            not self.xpu_communicator.disabled:
            return self.xpu_communicator.all_reduce(input_)

-        if self.ca_comm is not None and \
-            not self.ca_comm.disabled and \
-            self.ca_comm.should_custom_ar(input_):
-            return torch.ops.vllm.outplace_all_reduce(
-                input_, group_name=self.unique_name)
-        else:
-            torch.ops.vllm.inplace_all_reduce(input_,
-                                              group_name=self.unique_name)
-            return input_
+        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)

    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        # always try custom allreduce first,
+        # and then pynccl.
        ca_comm = self.ca_comm
-        assert ca_comm is not None
-        assert not ca_comm.disabled
-        out = ca_comm.custom_all_reduce(input_)
-        assert out is not None
-        return out
-
-    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
+        if ca_comm is not None and not ca_comm.disabled and \
+            ca_comm.should_custom_ar(input_):
+            out = ca_comm.custom_all_reduce(input_)
+            assert out is not None
+            return out
        pynccl_comm = self.pynccl_comm
-        if (pynccl_comm is not None and not pynccl_comm.disabled):
-            pynccl_comm.all_reduce(input_)
-        else:
-            torch.distributed.all_reduce(input_, group=self.device_group)
+        assert pynccl_comm is not None
+        # TODO: pynccl should not use `stream=`
+        # it can just always use the current stream.
+        out = pynccl_comm.all_reduce(input_,
+                                     stream=torch.cuda.current_stream())
+        if out is None:
+            # fall back to the default all-reduce using PyTorch.
+            # this usually happens during testing.
+            # when we run the model, allreduce only happens for the TP
+            # group, where we always have either custom allreduce or pynccl.
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+        return out

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
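The rewritten `_all_reduce_out_place` tries the custom allreduce kernel first, then pynccl, and finally falls back to a plain PyTorch all-reduce on a clone of the input. A hedged, single-process sketch of just that final fallback path, using the gloo backend so it runs without GPUs (the function name and the one-rank setup are illustrative, not vLLM code):

```python
import os

import torch
import torch.distributed as dist


def fallback_all_reduce(x: torch.Tensor, group=None) -> torch.Tensor:
    out = x.clone()                    # keep the caller's tensor untouched
    dist.all_reduce(out, group=group)  # reduce the copy in place
    return out


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.arange(4, dtype=torch.float32)
    y = fallback_all_reduce(x)
    assert torch.equal(y, x)  # with a single rank, the sum equals the input
    dist.destroy_process_group()
```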