@@ -258,13 +258,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.transformer = GPT2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
+        self.lm_head = ParallelLMHead(self.config.vocab_size,
+                                      self.config.hidden_size,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.lm_head")
         if self.config.tie_word_embeddings:
-            self.lm_head = self.transformer.wte
-        else:
-            self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size,
-                                          quant_config=quant_config,
-                                          prefix=f"{prefix}.lm_head")
+            self.lm_head = self.lm_head.tie_weights(self.transformer.wte)
+
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
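The `__init__` change above always constructs a `ParallelLMHead` and, when `config.tie_word_embeddings` is set, ties it to the embedding table `transformer.wte` via `tie_weights`, instead of pointing `lm_head` directly at `wte`. Below is a minimal sketch of that tie-after-construction pattern in plain PyTorch; `TiedLMHead` and its `tie_weights` method are hypothetical stand-ins for vLLM's `ParallelLMHead`, whose quantization handling is not shown in this diff.

```python
import torch
import torch.nn as nn


class TiedLMHead(nn.Module):
    """Hypothetical stand-in for an LM head whose weight can be tied to an embedding."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))
        nn.init.normal_(self.weight, std=0.02)

    def tie_weights(self, embed_tokens: nn.Embedding) -> "TiedLMHead":
        # Share the embedding matrix as the output projection so both modules
        # read and update the same parameter.
        self.weight = embed_tokens.weight
        return self

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states to vocabulary logits.
        return hidden_states @ self.weight.t()


# Mirror the diff's pattern: build the head unconditionally, then tie if configured.
vocab_size, hidden_size = 50257, 768
wte = nn.Embedding(vocab_size, hidden_size)   # analogue of transformer.wte
lm_head = TiedLMHead(vocab_size, hidden_size)
tie_word_embeddings = True
if tie_word_embeddings:
    lm_head = lm_head.tie_weights(wte)
assert lm_head.weight is wte.weight           # a single shared tensor
```

Constructing the head first keeps a single code path whether or not the weights end up tied, which appears to be the motivation for dropping the old `else` branch.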
@@ -309,15 +309,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if name.startswith("lm_head"):
-                # GPT-2 ties the weights of the embedding layer and the final
-                # linear layer.
-                continue
             if ".attn.bias" in name or ".attn.masked_bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
-            if not name.startswith("transformer."):
+            if not name.startswith("transformer.") and not name.startswith(
+                    "lm_head"):
                 name = "transformer." + name
 
             if is_pp_missing_parameter(name, self):
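With the head now a real parameter, `load_weights` no longer skips `lm_head.*` checkpoint entries; the renaming condition simply avoids prefixing them with `transformer.`. A small illustration of the renaming logic, with a hypothetical `_remap` helper extracted from the loop:

```python
def _remap(name: str) -> str:
    # Unprefixed checkpoint keys such as "wte.weight" or
    # "h.0.attn.c_attn.weight" get the "transformer." prefix, while
    # "lm_head.*" and already-prefixed names pass through unchanged.
    if not name.startswith("transformer.") and not name.startswith("lm_head"):
        name = "transformer." + name
    return name


assert _remap("wte.weight") == "transformer.wte.weight"
assert _remap("lm_head.weight") == "lm_head.weight"
assert _remap("transformer.h.0.mlp.c_fc.weight") == "transformer.h.0.mlp.c_fc.weight"
```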