From e3617383fbe7561d66fdf852422b997f305da613 Mon Sep 17 00:00:00 2001
From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com>
Date: Mon, 28 Apr 2025 21:09:47 +0000
Subject: [PATCH 1/4] make Eagle model arch config driven

---
 vllm/config.py                           |  3 ++-
 vllm/transformers_utils/configs/eagle.py | 16 +++++++++++++++-
 vllm/v1/spec_decode/eagle.py             | 16 +++++-----------
 3 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index e645103557c..3ed1674b5f3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2401,7 +2401,8 @@ def __post_init__(self):
                     pass
                 else:
                     eagle_config = EAGLEConfig(
-                        self.draft_model_config.hf_config)
+                        self.draft_model_config.hf_config,
+                        method=self.method)
                     self.draft_model_config.hf_config = eagle_config
 
                 if (self.num_speculative_tokens is not None
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 3a9ad3e0ffc..b09c66a2d22 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -15,6 +15,7 @@ class EAGLEConfig(PretrainedConfig):
     def __init__(self,
                  model: Union[PretrainedConfig, dict, None] = None,
                  truncated_vocab_size: Optional[int] = None,
+                 method: Optional[str] = 'eagle',
                  **kwargs):
 
         model_config: Union[PretrainedConfig, DeepseekV2Config, None]
@@ -45,7 +46,20 @@ def __init__(self,
         if not envs.VLLM_USE_V1:
             kwargs["architectures"] = ["EAGLEModel"]
         else:
-            kwargs["architectures"] = ["EagleLlamaForCausalLM"]
+            # Eagle model name should follow naming convention of 
+            # LlamaForCausalLM -> EagleLlamaForCausalLM
+            if method == "eagle":
+                kwargs["architectures"] = [
+                    f"Eagle{arch}" for arch in self.model.architectures
+                ] 
+            elif method == "eagle3":
+                kwargs["architectures"] = [
+                    f"Eagle3{arch}" for arch in self.model.architectures
+                ] 
+            else:
+                raise ValueError(
+                    f"Invalid method {method}. \
+                        Supported methods are eagle and eagle3.")
 
         super().__init__(**kwargs)
 
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 1de14584d39..71d952803cc 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,8 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
-from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.model_executor.models import ModelRegistry
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -225,15 +224,10 @@ def load_model(self, target_model: nn.Module) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            if self.vllm_config.speculative_config.method == "eagle":
-                self.model = EagleLlamaForCausalLM(
-                    model_config=draft_model_config,
-                    start_layer_id=target_layer_num).to(target_device)
-            else:
-                assert self.vllm_config.speculative_config.method == "eagle3"
-                self.model = Eagle3LlamaForCausalLM(
-                    model_config=draft_model_config,
-                    start_layer_id=target_layer_num).to(target_device)
+            draft_model_cls, arch = ModelRegistry.resolve_model_cls(draft_model_config.architectures)
+            self.model = draft_model_cls(
+                model_config=draft_model_config,
+                start_layer_id=target_layer_num).to(target_device)
 
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(

From bb165a53dfcfb12436871006ce517ecc61a955ed Mon Sep 17 00:00:00 2001
From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com>
Date: Mon, 28 Apr 2025 21:15:53 +0000
Subject: [PATCH 2/4] fix lint

---
 vllm/v1/spec_decode/eagle.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 71d952803cc..8c45ca9a319 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -224,7 +224,8 @@ def load_model(self, target_model: nn.Module) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            draft_model_cls, arch = ModelRegistry.resolve_model_cls(draft_model_config.architectures)
+            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
+                draft_model_config.architectures)
             self.model = draft_model_cls(
                 model_config=draft_model_config,
                 start_layer_id=target_layer_num).to(target_device)

From d8db3540d115399aa3e0093ba08e438ea35322ff Mon Sep 17 00:00:00 2001
From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com>
Date: Mon, 28 Apr 2025 22:33:50 +0000
Subject: [PATCH 3/4] pre-commit

---
 vllm/transformers_utils/configs/eagle.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index b09c66a2d22..e76e8505e76 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -46,19 +46,18 @@ def __init__(self,
         if not envs.VLLM_USE_V1:
             kwargs["architectures"] = ["EAGLEModel"]
         else:
-            # Eagle model name should follow naming convention of 
+            # Eagle model name should follow naming convention of
             # LlamaForCausalLM -> EagleLlamaForCausalLM
             if method == "eagle":
                 kwargs["architectures"] = [
                     f"Eagle{arch}" for arch in self.model.architectures
-                ] 
+                ]
             elif method == "eagle3":
                 kwargs["architectures"] = [
                     f"Eagle3{arch}" for arch in self.model.architectures
-                ] 
+                ]
             else:
-                raise ValueError(
-                    f"Invalid method {method}. \
+                raise ValueError(f"Invalid method {method}. \
                         Supported methods are eagle and eagle3.")
 
         super().__init__(**kwargs)

From 89fc323c4ebfbf81e7baf67fb4f53db2f5dbb8ef Mon Sep 17 00:00:00 2001
From: ekagra <3116519+ekagra-ranjan@users.noreply.github.com>
Date: Mon, 28 Apr 2025 22:53:06 +0000
Subject: [PATCH 4/4] fix mypy

---
 vllm/transformers_utils/configs/eagle.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index e76e8505e76..586d5c7f5e5 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -49,10 +49,14 @@ def __init__(self,
             # Eagle model name should follow naming convention of
             # LlamaForCausalLM -> EagleLlamaForCausalLM
             if method == "eagle":
+                assert self.model is not None, \
+                    "model should not be None when method is eagle"
                 kwargs["architectures"] = [
                     f"Eagle{arch}" for arch in self.model.architectures
                 ]
             elif method == "eagle3":
+                assert self.model is not None, \
+                    "model should not be None when method is eagle3"
                 kwargs["architectures"] = [
                     f"Eagle3{arch}" for arch in self.model.architectures
                 ]
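
Note (not part of the patch series): a minimal standalone sketch of the naming convention these patches rely on. The helper name derive_draft_architectures and the example inputs are illustrative only; in the patches the same mapping happens inside EAGLEConfig.__init__, and the resulting names are then resolved to draft-model classes via ModelRegistry.resolve_model_cls() in load_model().

from typing import List

def derive_draft_architectures(method: str, target_archs: List[str]) -> List[str]:
    # Mirror the config-driven convention: LlamaForCausalLM -> EagleLlamaForCausalLM
    # (or Eagle3LlamaForCausalLM for the eagle3 method).
    if method == "eagle":
        return [f"Eagle{arch}" for arch in target_archs]
    if method == "eagle3":
        return [f"Eagle3{arch}" for arch in target_archs]
    raise ValueError(f"Invalid method {method}. Supported methods are eagle and eagle3.")

# The derived name must still correspond to a draft model registered in the
# model registry; the prefixing alone does not create a new implementation.
print(derive_draft_architectures("eagle", ["LlamaForCausalLM"]))   # ['EagleLlamaForCausalLM']
print(derive_draft_architectures("eagle3", ["LlamaForCausalLM"]))  # ['Eagle3LlamaForCausalLM']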