@@ -61,14 +61,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.truncated_vocab_size = config.truncated_vocab_size
         self.unpadded_vocab_size = self.truncated_vocab_size

-        self.lm_heads = nn.ModuleList([
-            ParallelLMHead(
+        if getattr(config, "original_lm_head", False):
+            self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,
                 config.hidden_size,
                 org_num_embeddings=self.truncated_vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-            ) for _ in range(self.config.num_heads)
-        ])
+            )
+            self.lm_heads = [
+                self.lm_head for _ in range(self.config.num_heads)
+            ]
+        else:
+            self.lm_heads = nn.ModuleList([
+                ParallelLMHead(
+                    self.unpadded_vocab_size,
+                    config.hidden_size,
+                    org_num_embeddings=self.truncated_vocab_size,
+                    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                ) for _ in range(self.config.num_heads)
+            ])

         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
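With original_lm_head set, a single ParallelLMHead is built and every entry of lm_heads aliases it, so all Medusa heads project through one shared weight matrix; otherwise the behaviour is unchanged, with one independent head per Medusa head. Below is a minimal sketch of that aliasing, using torch.nn.Linear as a stand-in for ParallelLMHead (which needs vLLM's tensor-parallel runtime); the function and argument names are illustrative, not taken from the diff.

# Minimal sketch of the two construction paths above. torch.nn.Linear stands in
# for ParallelLMHead, and build_heads/its arguments are illustrative names.
import torch.nn as nn


def build_heads(num_heads: int, hidden_size: int, vocab_size: int,
                original_lm_head: bool):
    if original_lm_head:
        # Build one head and let every list entry alias it: all Medusa heads
        # then share a single projection weight.
        lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        return [lm_head for _ in range(num_heads)]
    # Default path: each head owns an independent projection.
    return nn.ModuleList([
        nn.Linear(hidden_size, vocab_size, bias=False)
        for _ in range(num_heads)
    ])


shared = build_heads(4, 32, 128, original_lm_head=True)
separate = build_heads(4, 32, 128, original_lm_head=False)
assert all(h.weight is shared[0].weight for h in shared)    # one shared tensor
assert separate[0].weight is not separate[1].weight         # distinct tensors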
@@ -172,6 +183,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                   requires_grad=False)
             elif name in params_dict:
                 weights_map[name] = loaded_weight
+            elif (getattr(self.config, "original_lm_head", False)
+                  and name == "lm_heads.0.weight"):
+                weights_map["lm_head.weight"] = loaded_weight

         for name, loaded_weight in weights_map.items():
             if "lm_head" in name and self.token_map is not None and \