@@ -729,7 +729,22 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
-        loaded_weights = [(name, loaded_weight)
-                          for name, loaded_weight in weights]
         mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-        return loader.load_weights(loaded_weights, mapper=mapper)
+        # add fake zeros bias for k_proj to state_dict
+        weights = _create_fake_bias_for_k_proj(weights)
+        return loader.load_weights(weights, mapper=mapper)
+
+
+def _create_fake_bias_for_k_proj(
+        weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Iterable[Tuple[str, torch.Tensor]]:
+    """
+    Create an all-zeros bias for each k_proj weight in self-attention layers
+    so that the k_proj slice of the fused qkv_proj bias initializes to zeros.
+    """
+    for name, weight in weights:
+        yield name, weight
+        if ".self_attn.k_proj.weight" in name:
+            bias = torch.zeros(weight.size(0))
+            bias_name = name.replace("weight", "bias")
+            yield bias_name, bias
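
Background for this change: vLLM fuses the separate q_proj/k_proj/v_proj layers into a single qkv_proj module whose bias is one concatenated tensor, so a checkpoint whose k_proj carries no bias (HF Whisper creates k_proj with bias=False, for example) would leave the k slice of the fused bias without a source tensor. Injecting an all-zeros k_proj bias into the weight stream lets the fused bias load cleanly. Below is a minimal, self-contained sketch of the generator's effect on a weight stream; the layer names and shapes are made up for illustration and are not taken from a real checkpoint.

from typing import Iterable, Tuple

import torch


def _create_fake_bias_for_k_proj(
        weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    # Same logic as the helper in the diff: pass every entry through
    # unchanged, and follow each k_proj weight with a zeros bias of
    # matching output dimension.
    for name, weight in weights:
        yield name, weight
        if ".self_attn.k_proj.weight" in name:
            yield name.replace("weight", "bias"), torch.zeros(weight.size(0))


# Hypothetical weight stream; q_proj has a bias in the checkpoint,
# k_proj does not.
stream = [
    ("model.encoder.layers.0.self_attn.q_proj.weight", torch.randn(8, 8)),
    ("model.encoder.layers.0.self_attn.q_proj.bias", torch.randn(8)),
    ("model.encoder.layers.0.self_attn.k_proj.weight", torch.randn(8, 8)),
]

for name, tensor in _create_fake_bias_for_k_proj(stream):
    print(name, tuple(tensor.shape))
# ...q_proj.weight (8, 8)
# ...q_proj.bias (8,)
# ...k_proj.weight (8, 8)
# ...k_proj.bias (8,)   <- injected zeros, fills the k slice of the fused bias

Because the helper is a generator, the stream is rewritten lazily: no extra copy of the weights is materialized, which is also why the diff drops the intermediate loaded_weights list.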