
Commit 2cd4d58

[Model] use AutoWeightsLoader for gpt2 (#18625)
Signed-off-by: zt2370 <[email protected]>
1 parent 6d166a8 commit 2cd4d58
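
In brief: GPT2Model gains its own load_weights, and GPT2LMHeadModel now delegates to the shared AutoWeightsLoader helper instead of duplicating the loading loop. Below is a simplified, hypothetical sketch of the delegation idea (ToyAutoWeightsLoader is invented for illustration; the real AutoWeightsLoader in vllm/model_executor/models/utils.py handles more cases, such as prefix filtering and skipped modules). It is only meant to show why giving the submodule its own load_weights lets the top-level loader shrink to three lines.

from collections.abc import Iterable

import torch
import torch.nn as nn


class ToyAutoWeightsLoader:
    """Invented for illustration; not vLLM's actual implementation."""

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        children = dict(self.module.named_children())
        own_params = dict(self.module.named_parameters(recurse=False))
        per_child: dict[str, list[tuple[str, torch.Tensor]]] = {}
        for name, tensor in weights:
            head, _, rest = name.partition(".")
            if head in children and rest:
                # Strip the child's prefix and hand the remainder down.
                per_child.setdefault(head, []).append((rest, tensor))
            elif name in own_params:
                own_params[name].data.copy_(tensor)
                loaded.add(name)
        for head, sub_weights in per_child.items():
            child = children[head]
            if hasattr(child, "load_weights"):
                # e.g. GPT2Model.load_weights after this commit.
                sub_loaded = child.load_weights(sub_weights)
            else:
                sub_loaded = ToyAutoWeightsLoader(child).load_weights(
                    sub_weights)
            loaded.update(f"{head}.{n}" for n in sub_loaded)
        return loaded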

File tree

1 file changed: +43 -30

vllm/model_executor/models/gpt2.py

Lines changed: 43 additions & 30 deletions
@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -235,6 +235,35 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPT2LMHeadModel(nn.Module, SupportsPP):
 
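Context for the transposes above (not part of the diff): Hugging Face's GPT-2 implements c_attn, c_proj, and c_fc with transformers' Conv1D, whose weight is stored with shape (in_features, out_features), the transpose of nn.Linear's (out_features, in_features) layout. A minimal check of the equivalence that the ".t()" relies on:

import torch

in_features, out_features = 4, 8
# HF's Conv1D stores weight with shape (in_features, out_features)
# and computes x @ W + b.
w_conv1d = torch.randn(in_features, out_features)
x = torch.randn(2, in_features)

linear = torch.nn.Linear(in_features, out_features, bias=False)
with torch.no_grad():
    # The ".t()" in load_weights converts to nn.Linear's
    # (out_features, in_features) layout.
    linear.weight.copy_(w_conv1d.t())

# Same outputs once transposed:
assert torch.allclose(x @ w_conv1d, linear(x), atol=1e-6)
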
@@ -283,32 +312,16 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if ".attn.bias" in name or ".attn.masked_bias" in name:
-                # Skip attention mask.
-                # NOTE: "c_attn.bias" should not be skipped.
-                continue
-            if not name.startswith("transformer.") and not name.startswith(
-                    "lm_head"):
-                name = "transformer." + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            param = params_dict[name]
-            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-            # Because of this, we need to transpose the weights.
-            # Note(zhuohan): the logic below might break quantized models.
-            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                if conv1d_weight_name not in name:
-                    continue
-                if not name.endswith(".weight"):
-                    continue
-                loaded_weight = loaded_weight.t()
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        weights = _add_transformer_prefix(weights)
+        return loader.load_weights(weights)
+
+
+def _add_transformer_prefix(
+        weights: Iterable[tuple[str, torch.Tensor]]
+) -> Iterable[tuple[str, torch.Tensor]]:
+    for name, tensor in weights:
+        if not name.startswith('transformer.') and not name.startswith(
+                "lm_head"):
+            name = 'transformer.' + name
+        yield name, tensor
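
Illustration with hypothetical checkpoint names: _add_transformer_prefix remaps bare GPT-2 parameter names onto the transformer submodule so AutoWeightsLoader can resolve them against GPT2LMHeadModel, while leaving lm_head and already-prefixed names untouched:

def _add_transformer_prefix(weights):
    # Same logic as the generator added above.
    for name, tensor in weights:
        if not name.startswith('transformer.') and not name.startswith(
                "lm_head"):
            name = 'transformer.' + name
        yield name, tensor

names = ["wte.weight", "h.0.attn.c_attn.weight",
         "transformer.ln_f.weight", "lm_head.weight"]
for name, _ in _add_transformer_prefix((n, None) for n in names):
    print(name)
# transformer.wte.weight
# transformer.h.0.attn.c_attn.weight
# transformer.ln_f.weight  (already prefixed, unchanged)
# lm_head.weight  (left as-is so it resolves at the top level)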
