Skip to content

Commit ce6f880

Browse files
ilmarkov
authored and committed
[Distributed] [ROCM] Fix custom allreduce enable checks (vllm-project#16010)
Signed-off-by: ilmarkov <[email protected]> Co-authored-by: ilmarkov <[email protected]> Signed-off-by: xinyuxiao <[email protected]>
1 parent d41ffdc commit ce6f880

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

vllm/config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,13 +1620,12 @@ def _verify_args(self) -> None:
16201620
if self.use_ray:
16211621
from vllm.executor import ray_utils
16221622
ray_utils.assert_ray_available()
1623-
device_capability = current_platform.get_device_capability()
1624-
if (current_platform.is_rocm() and device_capability is not None
1625-
and device_capability < (9, 4)):
1623+
1624+
if not current_platform.use_custom_allreduce():
16261625
self.disable_custom_all_reduce = True
16271626
logger.info(
16281627
"Disabled the custom all-reduce kernel because it is not "
1629-
"supported on AMD GPUs older than MI300X.")
1628+
"supported on current platform.")
16301629
if self.ray_workers_use_nsight and not self.use_ray:
16311630
raise ValueError("Unable to use nsight profiling unless workers "
16321631
"run with Ray.")

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ def supports_fp8(cls) -> bool:
308308
def supports_v1(cls, model_config: ModelConfig) -> bool:
309309
return True
310310

@classmethod
def use_custom_allreduce(cls) -> bool:
    """Return whether the custom all-reduce kernel may be used.

    CUDA devices always support the custom all-reduce path, so this
    unconditionally reports support.
    """
    return True
311315

312316
# NVML utils
313317
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
379379
"""
380380
return False
381381

@classmethod
def use_custom_allreduce(cls) -> bool:
    """
    Returns if custom allreduce is supported on the current platform
    """
    # Conservative default: platform backends that ship a custom
    # all-reduce kernel (e.g. CUDA, ROCm on MI300) override this.
    return False
382389

383390
class UnspecifiedPlatform(Platform):
384391
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/rocm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,10 @@ def fp8_dtype(cls) -> torch.dtype:
302302
def supports_v1(cls, model_config: ModelConfig) -> bool:
303303
# V1 support on AMD gpus is experimental
304304
return True
@classmethod
def use_custom_allreduce(cls) -> bool:
    """Return whether the custom all-reduce kernel is enabled on ROCm.

    Custom allreduce is restricted to the MI300 series, i.e. GCN
    architectures whose name contains 'gfx94'.
    """
    # We only enable custom allreduce for MI300 series
    arch_name = torch.cuda.get_device_properties(0).gcnArchName
    return 'gfx94' in arch_name

0 commit comments

Comments (0)