@@ -47,7 +47,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -321,6 +321,56 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name == "lm_head.weight":
+                # Unlike Baichuan, Baichuan2 normalizes the head weights.
+                # Refer to:
+                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+                # Distinguish between Baichuan and Baichuan2 by checking the
+                # vocab size. This is suggested by
+                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+                is_baichuan2 = self.config.vocab_size == 125696
+                if is_baichuan2:
+                    loaded_weight = torch.nn.functional.normalize(
+                        loaded_weight)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                               SupportsQuant):
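The load_weights body that moves into BaiChuanModel keeps the stacked-parameter remapping: checkpoints ship gate_proj and up_proj as separate tensors, while the model holds one fused gate_up_proj whose weight_loader writes each shard into its slice. Below is a minimal standalone sketch of that mechanism; fused_shard_loader, the toy shapes, and the checkpoint dict are invented for illustration (the real loader lives on vLLM's merged linear layers and also handles tensor parallelism).

import torch

def fused_shard_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                       shard_id: int) -> None:
    # Each shard owns a contiguous block of rows in the fused parameter:
    # shard 0 -> the gate_proj rows, shard 1 -> the up_proj rows.
    shard_size = param.shape[0] // 2
    start = shard_id * shard_size
    param.data[start:start + shard_size].copy_(loaded_weight)

hidden, intermediate = 4, 8
# Fused parameter as the model stores it: gate_proj stacked on up_proj.
gate_up_proj = torch.empty(2 * intermediate, hidden)
# Checkpoint tensors still arrive under their original, unfused names.
checkpoint = {
    "mlp.gate_proj.weight": torch.randn(intermediate, hidden),
    "mlp.up_proj.weight": torch.randn(intermediate, hidden),
}
stacked_params_mapping = [
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]
for name, weight in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            fused_shard_loader(gate_up_proj, weight, shard_id)
            break

The break / for-else pairing in the diff serves the same routing purpose: a name that matches no shard entry falls through to the else branch and is loaded as an ordinary parameter.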
@@ -393,53 +443,8 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if name == "lm_head.weight":
-                # Unlike Baichuan, Baichuan2 normalizes the head weights.
-                # Refer to:
-                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
-                # Distinguish between Baichuan and Baichuan2 by checking the
-                # vocab size. This is suggested by
-                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
-                is_baichuan2 = self.config.vocab_size == 125696
-                if is_baichuan2:
-                    loaded_weight = torch.nn.functional.normalize(
-                        loaded_weight)
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
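After the move, the wrapper's load_weights collapses to two lines because AutoWeightsLoader hands each checkpoint tensor to the submodule that owns it, deferring to any load_weights the submodule defines (here, BaiChuanModel's). The following is a simplified sketch of that dispatch pattern, not vLLM's actual implementation; ToyAutoWeightsLoader is hypothetical, and the real AutoWeightsLoader imported from .utils handles additional concerns such as skipped prefixes.

from typing import Dict, Iterable, List, Set, Tuple

import torch
import torch.nn as nn

class ToyAutoWeightsLoader:
    """Route each (name, tensor) pair to the direct child module whose
    attribute name prefixes it, deferring to a child-defined load_weights."""

    def __init__(self, module: nn.Module):
        self.module = module

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loaded: Set[str] = set()
        params = dict(self.module.named_parameters())
        children = dict(self.module.named_children())
        # Bucket weights per child so a child that implements load_weights
        # (like BaiChuanModel above) sees its whole stream of tensors.
        buckets: Dict[str, List[Tuple[str, torch.Tensor]]] = {}
        for name, tensor in weights:
            head, _, rest = name.partition(".")
            child = children.get(head)
            if child is not None and hasattr(child, "load_weights"):
                buckets.setdefault(head, []).append((rest, tensor))
            else:
                # Plain parameter owned by this module (e.g. lm_head.weight).
                params[name].data.copy_(tensor)
                loaded.add(name)
        for head, bucket in buckets.items():
            for sub_name in children[head].load_weights(bucket):
                loaded.add(f"{head}.{sub_name}")
        return loaded

Usage then mirrors the two added lines in the diff: construct the loader around the top-level model and return whatever set of parameter names it reports as loaded.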