Skip to content

[ROCm][Bugfix] Bring back fallback to eager mode removed in #14917, but for ROCm only #15413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum
from vllm.platforms import CpuArchEnum, current_platform
from vllm.sampling_params import GuidedDecodingParams
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
Expand Down Expand Up @@ -684,6 +684,13 @@ def _verify_cuda_graph(self) -> None:
self.max_seq_len_to_capture = self.max_model_len
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
ROCM_UNSUPPORTED_MODELS = ['mllama']
if (self.hf_config.model_type in ROCM_UNSUPPORTED_MODELS
and not self.enforce_eager and current_platform.is_rocm()):
logger.warning(
"CUDA graph is not supported for %s on ROCm yet, fallback "
"to the eager mode.", self.hf_config.model_type)
self.enforce_eager = True

def _verify_bnb_config(self) -> None:
"""
Expand Down