Skip to content

[Distributed] [ROCM] Fix custom allreduce enable checks #16010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,13 +1619,12 @@ def _verify_args(self) -> None:
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
device_capability = current_platform.get_device_capability()
if (current_platform.is_rocm() and device_capability is not None
and device_capability < (9, 4)):

if not current_platform.use_custom_allreduce():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs older than MI300X.")
"supported on current platform.")
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ def supports_fp8(cls) -> bool:
def supports_v1(cls, model_config: ModelConfig) -> bool:
return True

@classmethod
def use_custom_allreduce(cls) -> bool:
return True


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
Expand Down
7 changes: 7 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,13 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
"""
return False

@classmethod
def use_custom_allreduce(cls) -> bool:
"""
Returns if custom allreduce is supported on the current platform
"""
return False


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
Expand Down
7 changes: 7 additions & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,10 @@ def fp8_dtype(cls) -> torch.dtype:
def supports_v1(cls, model_config: ModelConfig) -> bool:
# V1 support on AMD gpus is experimental
return True

@classmethod
def use_custom_allreduce(cls) -> bool:
# We only enable custom allreduce for MI300 series
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
supported_archs = ['gfx94']
return any(gfx in gcn_arch for gfx in supported_archs)