 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -605,9 +605,50 @@ def forward(
             return IntermediateTensors({"hidden_states": hidden_states})
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
+            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "rotary_pos_emb.inv_freq" in name:
+                    continue
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 
 class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={".word_embeddings": ""}, )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
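Note: the `stacked_params_mapping` above drives vLLM's fused-weight loading: both vision-MLP shards from the checkpoint end up in the single `merged_proj` parameter, with `shard_id` selecting the slice. Below is a minimal standalone sketch of that merge in plain PyTorch; the tensor names and sizes are illustrative, and tensor-parallel splitting is ignored.

```python
# Standalone sketch (plain PyTorch, no vLLM): how the stacked_params_mapping
# merges two checkpoint tensors into one fused parameter. In the diff itself
# the slicing is done by MergedColumnParallelLinear's weight_loader.
import torch

hidden, ffn = 8, 16  # illustrative sizes
checkpoint = {
    "vision.linear_proj.gate_proj.weight": torch.randn(ffn, hidden),
    "vision.linear_proj.dense_h_to_4h.weight": torch.randn(ffn, hidden),
}

# The fused parameter holds both shards stacked along dim 0.
merged = torch.empty(2 * ffn, hidden)

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
    ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
]

for name, weight in checkpoint.items():
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name not in name:
            continue
        # Copy this shard into its slice of the fused tensor.
        merged[shard_id * ffn:(shard_id + 1) * ffn].copy_(weight)
        break

assert torch.equal(merged[:ffn],
                   checkpoint["vision.linear_proj.gate_proj.weight"])
assert torch.equal(merged[ffn:],
                   checkpoint["vision.linear_proj.dense_h_to_4h.weight"])
```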
@@ -660,52 +701,9 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
-        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
-            "transformer.vision.linear_proj.merged_proj.weight": {
-                "transformer.vision.linear_proj.gate_proj.weight": None,
-                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
-            }
-        }
-
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            is_weight_to_be_merge = False
-            for _, merged_weight_dict in merged_weights_dict.items():
-                if name in merged_weight_dict:
-                    assert merged_weight_dict[name] is None
-                    merged_weight_dict[name] = loaded_weight
-                    is_weight_to_be_merge = True
-            if is_weight_to_be_merge:
-                continue
-            if "rotary_pos_emb.inv_freq" in name:
-                continue
-            if "word_embeddings" in name:
-                name = name.replace(".word_embeddings", "")
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-
-        for combined_name, merged_weight_dict in merged_weights_dict.items():
-            if combined_name in params_dict:
-                param = params_dict[combined_name]
-                combined_weight = torch.cat(list(merged_weight_dict.values()),
-                                            dim=0)
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, combined_weight)
-                loaded_params.add(combined_name)
-        return loaded_params
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 class ChatGLM(ChatGLMBaseModel):
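Note: `hf_to_vllm_mapper` plus `AutoWeightsLoader` replaces the hand-written loop removed above: the mapper rewrites checkpoint names (dropping the `.word_embeddings` substring) and the loader then walks submodules, deferring to any submodule-level `load_weights` such as the one added earlier. Below is a minimal sketch of just the renaming step; the helper and example name are illustrative stand-ins, not vLLM API.

```python
# Standalone sketch of the substring rename that hf_to_vllm_mapper performs
# before weights are dispatched to submodules. Only the mapping
# {".word_embeddings": ""} comes from the diff; the rest is illustrative.
def rename(name: str, orig_to_new_substr: dict) -> str:
    # Apply every substring rewrite to the checkpoint parameter name.
    for old, new in orig_to_new_substr.items():
        name = name.replace(old, new)
    return name

hf_name = "transformer.embedding.word_embeddings.weight"
vllm_name = rename(hf_name, {".word_embeddings": ""})
assert vllm_name == "transformer.embedding.weight"
```

With this split, the vision-specific shard merging lives next to the vision module's own `load_weights`, and the top-level model no longer needs a bespoke loading loop.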
@@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel):
 
 
 class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
+
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"],
@@ -777,7 +776,7 @@ def __new__(
     ) -> None:
         config = vllm_config.model_config.hf_config
         # Initialize VL
-        if hasattr(config, "visual"):
+        if hasattr(config, "vision_config"):
             return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
         # Initialize LLM
         else:
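Note: the dispatch in `__new__` now keys off `vision_config`, the attribute GLM-4V Hugging Face configs expose, rather than `visual`. A small illustrative sketch of the same attribute check follows; the config stubs are hypothetical stand-ins for `hf_config`.

```python
# Illustrative sketch of the check used in __new__ above: pick the
# multimodal ChatGLMV only when the HF config carries a vision section.
from types import SimpleNamespace

def is_multimodal(hf_config) -> bool:
    return hasattr(hf_config, "vision_config")

assert is_multimodal(SimpleNamespace(vision_config={"image_size": 1120}))
assert not is_multimodal(SimpleNamespace(hidden_size=4096))
```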