
Commit 856d7ab

zinccat authored and mzusman committed
[Bugfix] Fix ColumnParallelLinearWithLoRA slice (vllm-project#11708)
Signed-off-by: ZincCat <[email protected]>
1 parent 41552d7 commit 856d7ab

File tree

1 file changed: +2 −2 lines changed


vllm/lora/layers.py

Lines changed: 2 additions & 2 deletions
@@ -479,7 +479,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         # ColumnParallelLinear.
         else:
             tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-            shard_size = self.output_dim
+            shard_size = self.output_size
             start_idx = tensor_model_parallel_rank * shard_size
             end_idx = (tensor_model_parallel_rank + 1) * shard_size
             lora_b = lora_b[:, start_idx:end_idx]
@@ -490,7 +490,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         if bias is None:
             return bias
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-        shard_size = self.output_dim
+        shard_size = self.output_size
         start_idx = tensor_model_parallel_rank * shard_size
         end_idx = (tensor_model_parallel_rank + 1) * shard_size
         bias = bias[start_idx:end_idx]
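For readers who want the slicing arithmetic spelled out: each tensor-parallel rank keeps one contiguous shard of the LoRA B weight (and its bias) along the output dimension, and the fix makes shard_size the per-rank output width (self.output_size) rather than the output_dim attribute. Judging from the diff alone, the old attribute yielded the wrong shard width, so each rank would slice a mis-sized chunk. Below is a minimal, self-contained sketch of the same slicing, not vLLM's actual code; tp_size, lora_rank, and full_output are made-up sizes, and the plain rank integer stands in for get_tensor_model_parallel_rank().

import torch

# Hypothetical sizes for illustration only; in vLLM, shard_size comes from
# the LoRA wrapper's per-rank output width, which this commit switches to.
tp_size = 4                           # assumed number of tensor-parallel ranks
rank = 1                              # stand-in for get_tensor_model_parallel_rank()
lora_rank = 16                        # assumed LoRA rank
full_output = 4096                    # assumed full output width of the base layer
shard_size = full_output // tp_size   # per-rank shard width (1024 here)

lora_b = torch.randn(lora_rank, full_output)  # full LoRA B weight
bias = torch.randn(full_output)               # full LoRA bias

# Each rank keeps a contiguous [start_idx, end_idx) slice of the output dim,
# mirroring slice_lora_b and slice_bias in vllm/lora/layers.py.
start_idx = rank * shard_size
end_idx = (rank + 1) * shard_size
lora_b_shard = lora_b[:, start_idx:end_idx]   # shape (16, 1024)
bias_shard = bias[start_idx:end_idx]          # shape (1024,)

assert lora_b_shard.shape == (lora_rank, shard_size)
assert bias_shard.shape == (shard_size,)

If shard_size is anything other than the rank's true partition width, start_idx and end_idx drift off the partition boundaries, which is why the one-attribute change above is sufficient to fix the slice.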
