@@ -627,7 +627,7 @@ def filter(self, indices):
627
627
628
628
class LogitBiasProcessor (LogitsProcessor ):
629
629
"""
630
- `LogitsProcessor ` creates a bias tensor from a dictionary of token IDs and their
630
+ `LogitBiasProcessor ` creates a bias tensor from a dictionary of token IDs and their
631
631
corresponding bias values. Biases are applied to the logits during each forward pass.
632
632
633
633
Supports token IDs provided as strings (e.g., {"9707": -100}).
@@ -656,7 +656,7 @@ def __init__(
656
656
def __call__ (self , input_ids : torch .Tensor , scores : torch .Tensor ) -> torch .Tensor :
657
657
# Apply bias tensor as a broadcasted addition
658
658
if self .bias_tensor .shape [0 ] != scores .shape [1 ]:
659
- # Fix if the bias tensor is smaller than the scores
659
+ # Pad the bias tensor to match the scores if it's smaller
660
660
self .bias_tensor = torch .nn .functional .pad (
661
661
self .bias_tensor , (0 , scores .shape [1 ] - self .bias_tensor .shape [0 ])
662
662
)
@@ -699,7 +699,7 @@ def __init__(
699
699
def __call__ (self , input_ids : torch .Tensor , scores : torch .Tensor ) -> torch .Tensor :
700
700
# Apply bias matrix as a broadcasted addition
701
701
if self .bias_matrix .shape [1 ] != scores .shape [1 ]:
702
- # Fix if the bias matrix is smaller than the scores
702
+ # Pad the bias matrix to match the scores if it's smaller
703
703
self .bias_matrix = torch .nn .functional .pad (
704
704
self .bias_matrix , (0 , scores .shape [1 ] - self .bias_matrix .shape [1 ])
705
705
)
0 commit comments