Skip to content

Commit 343041c

Browse files
authored
[model] Reduce medusa weight (#10454)
Signed-off-by: skylee-01 <[email protected]>
1 parent ed701ca commit 343041c

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

vllm/model_executor/models/medusa.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
6161
self.truncated_vocab_size = config.truncated_vocab_size
6262
self.unpadded_vocab_size = self.truncated_vocab_size
6363

64-
self.lm_heads = nn.ModuleList([
65-
ParallelLMHead(
64+
if getattr(config, "original_lm_head", False):
65+
self.lm_head = ParallelLMHead(
6666
self.unpadded_vocab_size,
6767
config.hidden_size,
6868
org_num_embeddings=self.truncated_vocab_size,
6969
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
70-
) for _ in range(self.config.num_heads)
71-
])
70+
)
71+
self.lm_heads = [
72+
self.lm_head for _ in range(self.config.num_heads)
73+
]
74+
else:
75+
self.lm_heads = nn.ModuleList([
76+
ParallelLMHead(
77+
self.unpadded_vocab_size,
78+
config.hidden_size,
79+
org_num_embeddings=self.truncated_vocab_size,
80+
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
81+
) for _ in range(self.config.num_heads)
82+
])
7283

7384
logit_scale = getattr(config, "logit_scale", 1.0)
7485
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
@@ -172,6 +183,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
172183
requires_grad=False)
173184
elif name in params_dict:
174185
weights_map[name] = loaded_weight
186+
elif (getattr(self.config, "original_lm_head", False)
187+
and name == "lm_heads.0.weight"):
188+
weights_map["lm_head.weight"] = loaded_weight
175189

176190
for name, loaded_weight in weights_map.items():
177191
if "lm_head" in name and self.token_map is not None and\

0 commit comments

Comments (0)