Skip to content

Commit ef608c3

Browse files
ilmarkov
and
ilmarkov
authored
[Distributed] [ROCM] Fix custom allreduce enable checks (#16010)
Signed-off-by: ilmarkov <[email protected]> Co-authored-by: ilmarkov <[email protected]>
1 parent 2386803 commit ef608c3

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

vllm/config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,13 +1619,12 @@ def _verify_args(self) -> None:
16191619
if self.use_ray:
16201620
from vllm.executor import ray_utils
16211621
ray_utils.assert_ray_available()
1622-
device_capability = current_platform.get_device_capability()
1623-
if (current_platform.is_rocm() and device_capability is not None
1624-
and device_capability < (9, 4)):
1622+
1623+
if not current_platform.use_custom_allreduce():
16251624
self.disable_custom_all_reduce = True
16261625
logger.info(
16271626
"Disabled the custom all-reduce kernel because it is not "
1628-
"supported on AMD GPUs older than MI300X.")
1627+
"supported on current platform.")
16291628
if self.ray_workers_use_nsight and not self.use_ray:
16301629
raise ValueError("Unable to use nsight profiling unless workers "
16311630
"run with Ray.")

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ def supports_fp8(cls) -> bool:
308308
def supports_v1(cls, model_config: ModelConfig) -> bool:
309309
return True
310310

311+
@classmethod
312+
def use_custom_allreduce(cls) -> bool:
313+
return True
314+
311315

312316
# NVML utils
313317
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
379379
"""
380380
return False
381381

382+
@classmethod
383+
def use_custom_allreduce(cls) -> bool:
384+
"""
385+
Returns if custom allreduce is supported on the current platform
386+
"""
387+
return False
388+
382389

383390
class UnspecifiedPlatform(Platform):
384391
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/rocm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,10 @@ def fp8_dtype(cls) -> torch.dtype:
302302
def supports_v1(cls, model_config: ModelConfig) -> bool:
303303
# V1 support on AMD gpus is experimental
304304
return True
305+
306+
@classmethod
307+
def use_custom_allreduce(cls) -> bool:
308+
# We only enable custom allreduce for MI300 series
309+
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
310+
supported_archs = ['gfx94']
311+
return any(gfx in gcn_arch for gfx in supported_archs)

0 commit comments

Comments (0)