File tree 1 file changed +2
-2
lines changed 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -479,7 +479,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
479
479
# ColumnParallelLinear.
480
480
else :
481
481
tensor_model_parallel_rank = get_tensor_model_parallel_rank ()
482
- shard_size = self .output_dim
482
+ shard_size = self .output_size
483
483
start_idx = tensor_model_parallel_rank * shard_size
484
484
end_idx = (tensor_model_parallel_rank + 1 ) * shard_size
485
485
lora_b = lora_b [:, start_idx :end_idx ]
@@ -490,7 +490,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
490
490
if bias is None :
491
491
return bias
492
492
tensor_model_parallel_rank = get_tensor_model_parallel_rank ()
493
- shard_size = self .output_dim
493
+ shard_size = self .output_size
494
494
start_idx = tensor_model_parallel_rank * shard_size
495
495
end_idx = (tensor_model_parallel_rank + 1 ) * shard_size
496
496
bias = bias [start_idx :end_idx ]
You can’t perform that action at this time.
0 commit comments