
Commit ce69f7f

[Bugfix] Fix gpt2 GGUF inference (#12467)
Signed-off-by: Isotr0py <[email protected]>
1 parent 624a1e4

File tree

1 file changed (+8, -11 lines)

vllm/model_executor/models/gpt2.py

Lines changed: 8 additions & 11 deletions
@@ -258,13 +258,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.transformer = GPT2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
+        self.lm_head = ParallelLMHead(self.config.vocab_size,
+                                      self.config.hidden_size,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.lm_head")
         if self.config.tie_word_embeddings:
-            self.lm_head = self.transformer.wte
-        else:
-            self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size,
-                                          quant_config=quant_config,
-                                          prefix=f"{prefix}.lm_head")
+            self.lm_head = self.lm_head.tie_weights(self.transformer.wte)
+
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
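With this hunk the LM head is always constructed as a ParallelLMHead, and tying is applied afterwards via tie_weights instead of aliasing the embedding module directly, which gives checkpoints that store lm_head tensors separately (as GGUF files can) a real module to load them into. Below is a minimal sketch of that tying pattern in plain PyTorch; TiedLMHead and its tie_weights method are hypothetical stand-ins, not vLLM's classes.

import torch
import torch.nn as nn

class TiedLMHead(nn.Module):
    """Hypothetical stand-in for an always-constructed LM head."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))

    def tie_weights(self, embedding: nn.Embedding) -> "TiedLMHead":
        # Share the embedding's parameter; the head module itself (and any
        # metadata attached to it, e.g. a quantization config) stays alive.
        self.weight = embedding.weight
        return self

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states onto the (shared) vocabulary matrix.
        return hidden_states @ self.weight.t()

embedding = nn.Embedding(50257, 768)  # GPT-2's vocab and hidden sizes
lm_head = TiedLMHead(50257, 768).tie_weights(embedding)
assert lm_head.weight is embedding.weight  # a single shared tensor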
@@ -309,15 +309,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if name.startswith("lm_head"):
-                # GPT-2 ties the weights of the embedding layer and the final
-                # linear layer.
-                continue
             if ".attn.bias" in name or ".attn.masked_bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
-            if not name.startswith("transformer."):
+            if not name.startswith("transformer.") and not name.startswith(
+                    "lm_head"):
                 name = "transformer." + name
 
             if is_pp_missing_parameter(name, self):
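After this hunk the loader no longer drops lm_head.* entries; a checkpoint name is re-rooted under "transformer." only when it belongs to neither "transformer." nor "lm_head". A minimal sketch of that renaming rule follows; normalize_name is a hypothetical helper written for illustration, not vLLM's API.

def normalize_name(name: str) -> str:
    """Mirror the renaming rule above; returns "" for skipped tensors."""
    if ".attn.bias" in name or ".attn.masked_bias" in name:
        return ""  # attention mask buffers are skipped, not loaded
    if not name.startswith("transformer.") and not name.startswith("lm_head"):
        name = "transformer." + name
    return name

assert normalize_name("h.0.mlp.c_fc.weight") == "transformer.h.0.mlp.c_fc.weight"
assert normalize_name("lm_head.weight") == "lm_head.weight"  # no longer skipped
assert normalize_name("h.0.attn.bias") == ""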
