Skip to content

Commit a04720b

Browse files
[V1][Spec Decode][Bugfix] Load quantize weights for EAGLE (#18290)
1 parent 7b9d832 commit a04720b

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

vllm/transformers_utils/configs/eagle.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ def __init__(self,
             assert self.model is not None, \
                 "model should not be None when method is eagle"
             kwargs["architectures"] = [
-                f"Eagle{arch}" for arch in self.model.architectures
+                f"Eagle{arch}" if not arch.startswith("Eagle") \
+                    else arch for arch in self.model.architectures
             ]
         elif method == "eagle3":
             assert self.model is not None, \
                 "model should not be None when method is eagle3"
             kwargs["architectures"] = [
-                f"Eagle3{arch}" for arch in self.model.architectures
+                f"Eagle3{arch}" if not arch.startswith("Eagle3") \
+                    else arch for arch in self.model.architectures
             ]
         else:
             raise ValueError(f"Invalid method {method}. \

vllm/v1/spec_decode/eagle.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
@@ -308,6 +309,9 @@ def load_model(self, target_model: nn.Module) -> None:
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(draft_model_config, self.model))

+        process_weights_after_loading(self.model, draft_model_config,
+                                      target_device)
+
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
             assert "model.embed_tokens.weight" not in loaded_weights, \

0 commit comments

Comments (0)