
Commit 2b25b7d

Fix initializing GGUF weights for ColumnParallelLinear when using tensor parallel > 1 (#13023)
1 parent 6c4dbe2 commit 2b25b7d

1 file changed


vllm/model_executor/layers/linear.py

Lines changed: 12 additions & 7 deletions
@@ -335,6 +335,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
@@ -343,13 +349,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
 
         # Materialize GGUF UninitializedParameter
         if is_gguf_weight and isinstance(param, UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
-
-        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
-        is_sharded_weight = getattr(param, "is_sharded_weight", False)
-        # bitsandbytes loads the weights of the specific portion
-        # no need to narrow
-        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+            final_shape = list(loaded_weight.shape)
+            if output_dim is not None:
+                tp_size = get_tensor_model_parallel_world_size()
+                assert final_shape[output_dim] % tp_size == 0
+                final_shape[output_dim] = final_shape[output_dim] // tp_size
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
 
         param_data = param.data
         if output_dim is not None and not is_sharded_weight:
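Context for the change: before this patch, the GGUF UninitializedParameter was materialized with the full checkpoint shape, so under tensor parallel > 1 each rank allocated the whole weight instead of just its shard of the output dimension. Below is a minimal sketch of the per-rank shape arithmetic the patch performs before param.materialize(); the helper name shard_shape_for_rank is hypothetical and not part of vLLM, it only illustrates the computation.

def shard_shape_for_rank(full_shape, output_dim, tp_size):
    """Return the shape each tensor-parallel rank materializes locally."""
    shape = list(full_shape)
    if output_dim is not None:
        # The sharded dimension must divide evenly across ranks,
        # mirroring the assert added in the patched weight_loader.
        assert shape[output_dim] % tp_size == 0
        shape[output_dim] //= tp_size
    return shape

# Example: a GGUF weight of shape [8192, 4096] sharded on output_dim=0
# across tp_size=2 ranks is materialized as [4096, 4096] on each rank.
print(shard_shape_for_rank([8192, 4096], output_dim=0, tp_size=2))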
