@@ -335,6 +335,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -343,13 +349,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
 
         # Materialize GGUF UninitializedParameter
         if is_gguf_weight and isinstance(param, UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
-
-        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
-        is_sharded_weight = getattr(param, "is_sharded_weight", False)
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow
-        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+            final_shape = list(loaded_weight.shape)
+            if output_dim is not None:
+                tp_size = get_tensor_model_parallel_world_size()
+                assert final_shape[output_dim] % tp_size == 0
+                final_shape[output_dim] = final_shape[output_dim] // tp_size
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
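For context on the change above: the GGUF `UninitializedParameter` is now materialized with only this rank's shard along `output_dim` (full size divided by the tensor-parallel world size) instead of the full weight shape. Below is a minimal, self-contained sketch of that shape arithmetic and of selecting a rank's slice; it is not vLLM's API, and the helper names `shard_shape_for_rank` and `narrow_to_rank` are hypothetical, introduced here only for illustration. The assert mirrors the diff's requirement that the output dimension divide evenly by the tensor-parallel size, so integer division cannot silently drop rows.

```python
# Hypothetical sketch of the per-rank shard-shape logic applied in the diff.
from typing import List, Optional

import torch


def shard_shape_for_rank(full_shape: List[int], output_dim: Optional[int],
                         tp_size: int) -> List[int]:
    """Return the per-rank shape: full size divided by tp_size along output_dim."""
    shape = list(full_shape)
    if output_dim is not None:
        # Same invariant as the assert in the diff: the dim must split evenly.
        assert shape[output_dim] % tp_size == 0
        shape[output_dim] //= tp_size
    return shape


def narrow_to_rank(weight: torch.Tensor, output_dim: int, tp_rank: int,
                   tp_size: int) -> torch.Tensor:
    """Select this rank's contiguous slice of the full weight along output_dim."""
    shard_size = weight.shape[output_dim] // tp_size
    return weight.narrow(output_dim, tp_rank * shard_size, shard_size)


if __name__ == "__main__":
    full = torch.randn(1024, 256)  # e.g. a [out_features, in_features] weight
    tp_size, tp_rank, output_dim = 4, 1, 0
    print(shard_shape_for_rank(list(full.shape), output_dim, tp_size))  # [256, 256]
    print(narrow_to_rank(full, output_dim, tp_rank, tp_size).shape)     # torch.Size([256, 256])
```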