2 files changed: +9 −3 lines changed

transformers_utils/configs

@@ -52,13 +52,15 @@ def __init__(self,
             assert self.model is not None, \
                 "model should not be None when method is eagle"
             kwargs["architectures"] = [
-                f"Eagle{arch}" for arch in self.model.architectures
+                f"Eagle{arch}" if not arch.startswith("Eagle") \
+                    else arch for arch in self.model.architectures
             ]
         elif method == "eagle3":
             assert self.model is not None, \
                 "model should not be None when method is eagle3"
             kwargs["architectures"] = [
-                f"Eagle3{arch}" for arch in self.model.architectures
+                f"Eagle3{arch}" if not arch.startswith("Eagle3") \
+                    else arch for arch in self.model.architectures
             ]
         else:
             raise ValueError(f"Invalid method {method}. \
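
The intent of the hunk above is to make the architecture prefixing idempotent: when the draft model's config already lists an "Eagle"- or "Eagle3"-prefixed architecture, the prefix is not applied a second time (which would otherwise produce doubled names such as "EagleEagleLlamaForCausalLM"). A minimal standalone sketch of that rule follows; the helper name prefix_architectures is illustrative and not part of vLLM.

# Sketch (not vLLM code): apply a method prefix to architecture names
# exactly once, mirroring the conditional expression in the diff above.
def prefix_architectures(architectures: list[str], prefix: str) -> list[str]:
    """Return architecture names with `prefix` applied at most once."""
    return [
        arch if arch.startswith(prefix) else f"{prefix}{arch}"
        for arch in architectures
    ]

# A plain target architecture gets wrapped; an already-wrapped one is kept.
assert prefix_architectures(["LlamaForCausalLM"], "Eagle") == ["EagleLlamaForCausalLM"]
assert prefix_architectures(["EagleLlamaForCausalLM"], "Eagle") == ["EagleLlamaForCausalLM"]
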
Second changed file:

@@ -9,7 +9,8 @@
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.triton_utils import tl, triton
@@ -308,6 +309,9 @@ def load_model(self, target_model: nn.Module) -> None:
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(draft_model_config, self.model))
 
+        process_weights_after_loading(self.model, draft_model_config,
+                                      target_device)
+
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
             assert "model.embed_tokens.weight" not in loaded_weights, \