@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -235,6 +235,35 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            # HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class GPT2LMHeadModel(nn.Module, SupportsPP):

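[Note, not part of the diff] The transpose above is needed because HF's GPT-2 checkpoints store c_attn/c_proj/c_fc as transformers' Conv1D layers, whose weight layout is (in_features, out_features), i.e. the transpose of nn.Linear's (out_features, in_features). A minimal sketch of that layout difference (shapes here are illustrative, not from the diff):

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

conv = Conv1D(nf=12, nx=4)              # GPT-2 style: weight stored as (in=4, out=12)
linear = nn.Linear(4, 12)               # torch style: weight stored as (out=12, in=4)
linear.weight.data = conv.weight.t()    # the same transpose load_weights applies
linear.bias.data = conv.bias.data

x = torch.randn(2, 4)
assert torch.allclose(conv(x), linear(x), atol=1e-6)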
@@ -283,32 +312,16 @@ def compute_logits(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if ".attn.bias" in name or ".attn.masked_bias" in name:
-                # Skip attention mask.
-                # NOTE: "c_attn.bias" should not be skipped.
-                continue
-            if not name.startswith("transformer.") and not name.startswith(
-                    "lm_head"):
-                name = "transformer." + name
-
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            param = params_dict[name]
-            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-            # Because of this, we need to transpose the weights.
-            # Note(zhuohan): the logic below might break quantized models.
-            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                if conv1d_weight_name not in name:
-                    continue
-                if not name.endswith(".weight"):
-                    continue
-                loaded_weight = loaded_weight.t()
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        weights = _add_transformer_prefix(weights)
+        return loader.load_weights(weights)
+
+
+def _add_transformer_prefix(
+        weights: Iterable[tuple[str, torch.Tensor]]
+) -> Iterable[tuple[str, torch.Tensor]]:
+    for name, tensor in weights:
+        if not name.startswith("transformer.") and not name.startswith(
+                "lm_head"):
+            name = "transformer." + name
+        yield name, tensor
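[Note, not part of the diff] _add_transformer_prefix preserves checkpoint compatibility: GPT-2 checkpoints name tensors relative to the transformer (e.g. wte.weight, h.0.ln_1.weight), while GPT2LMHeadModel holds the backbone under self.transformer, so names are prefixed before AutoWeightsLoader routes them to submodules. That delegation is also why the per-tensor loop, with its Conv1D transpose, now lives on GPT2Model.load_weights, where AutoWeightsLoader can dispatch to it. A quick sketch of the renaming (assumes _add_transformer_prefix from this diff is in scope; tensors are dummies):

import torch

names = ["wte.weight", "h.0.ln_1.weight", "lm_head.weight"]
weights = ((name, torch.empty(0)) for name in names)
print([name for name, _ in _add_transformer_prefix(weights)])
# ['transformer.wte.weight', 'transformer.h.0.ln_1.weight', 'lm_head.weight']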