@@ -314,7 +314,8 @@ def forward(
314
314
hidden_states = self .norm (hidden_states )
315
315
return hidden_states
316
316
317
- def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]):
317
+ def load_weights (self , weights : Iterable [tuple [str ,
318
+ torch .Tensor ]]) -> set [str ]:
318
319
stacked_params_mapping = [
319
320
# (param_name, shard_name, shard_id)
320
321
("qkv_proj" , "q_proj" , "q" ),
@@ -325,6 +326,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
325
326
]
326
327
327
328
params_dict = dict (self .named_parameters (remove_duplicate = False ))
329
+ loaded_params : set [str ] = set ()
328
330
for name , loaded_weight in weights :
329
331
if is_pp_missing_parameter (name , self ):
330
332
continue
@@ -347,6 +349,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
347
349
weight_loader = getattr (param , "weight_loader" ,
348
350
default_weight_loader )
349
351
weight_loader (param , loaded_weight )
352
+ loaded_params .add (name )
353
+ return loaded_params
350
354
351
355
352
356
class Olmo2ForCausalLM (nn .Module , SupportsPP ):
0 commit comments