@@ -625,110 +625,92 @@ def filter(self, indices):
         return self
 
 
-class LogitBiasProcessor:
-    """Process logits with logit biases."""
+class LogitBiasProcessor(LogitsProcessor):
+    """
+    `LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their
+    corresponding bias values. Biases are applied to the logits during each forward pass.
+
+    Supports token IDs provided as strings (e.g., {"9707": -100}).
+    """
 
     def __init__(
-        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+        self,
+        logit_biases: dict,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ):
-        self.tokenizer = tokenizer
-        self.logit_biases = logit_biases or {}
+        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"
 
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        vocab_size = len(tokenizer)
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no logit biases, return scores unchanged
-        if not self.logit_biases:
-            return scores
-
-        # Apply bias to the corresponding scores
-        for token_str, bias_value in self.logit_biases.items():
-            # Get token ID, either from cache or by computing it
-            if token_str not in self.token_id_mapping:
-                if token_str.isdigit():
-                    # If the token string is already a numeric ID
-                    token_id = int(token_str)
-                else:
-                    # Otherwise, use the tokenizer to get the ID
-                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
-                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                self.token_id_mapping[token_str] = token_id
-
-            token_id = self.token_id_mapping[token_str]
-
-            # Apply bias if token ID is valid
-            if 0 <= token_id < scores.size(-1):
-                scores[:, token_id] += bias_value
+        # Convert keys to integers and values to a list
+        token_ids = torch.tensor(
+            [int(k) for k in logit_biases.keys()], dtype=torch.long
+        )
+        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
 
-        return scores
+        # Create a tensor and directly copy bias values at the corresponding indices
+        self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
+        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)
 
-    def filter(self, indices):
-        """Keep only the logit biases for the specified indices."""
-        new_logit_biases = {
-            k: self.logit_biases[k] for k in indices if k in self.logit_biases
-        }
-        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # Apply bias tensor as a broadcasted addition
+        if self.bias_tensor.shape[0] != scores.shape[1]:
+            # Pad if the bias tensor is smaller than the scores
+            self.bias_tensor = torch.nn.functional.pad(
+                self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
+            )
+        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
+        return scores
 
 
-class HeterogeneousLogitBiasProcessor:
-    """Process logits with different logit biases for each sequence in the batch."""
+class HeterogeneousLogitBiasProcessor(LogitsProcessor):
+    """
+    Process logits with different logit biases for each sequence in the batch.
+    """
 
     def __init__(
         self,
         logit_biases: List[Optional[dict]],
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ):
         self.device = device
         self.tokenizer = tokenizer
         self.logit_biases = logit_biases
-        self.batch_size = len(logit_biases)
+        self.vocab_size = len(tokenizer)
 
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        # Create batch_size x vocab_size bias matrix
+        self.bias_matrix = torch.zeros(
+            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
+        )
 
-        # Create a mapping of indices that have logit biases
-        self.indices_with_biases = {
-            i: bias_dict
-            for i, bias_dict in enumerate(self.logit_biases)
-            if bias_dict is not None and len(bias_dict) > 0
-        }
+        # For each logit bias dictionary, convert keys to integers and values to a list
+        for i, logit_bias in enumerate(logit_biases):
+            token_ids = torch.tensor(
+                [int(k) for k in (logit_bias or {}).keys()], dtype=torch.long
+            ).to(device=device)
+            bias_values = torch.tensor(
+                list((logit_bias or {}).values()), dtype=torch.float
+            ).to(device=device)
+            # Create a tensor and directly copy bias values at the corresponding indices
+            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no indices have biases, return scores unchanged
-        if not self.indices_with_biases:
-            return scores
-
-        # For each index with a bias, apply the bias to the corresponding scores
-        for i, bias_dict in self.indices_with_biases.items():
-            for token_str, bias_value in bias_dict.items():
-                # Get token ID, either from cache or by computing it
-                if token_str not in self.token_id_mapping:
-                    if token_str.isdigit():
-                        # If the token string is already a numeric ID
-                        token_id = int(token_str)
-                    else:
-                        # Otherwise, use the tokenizer to get the ID
-                        tokens = self.tokenizer.encode(
-                            token_str, add_special_tokens=False
-                        )
-                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                    self.token_id_mapping[token_str] = token_id
-
-                token_id = self.token_id_mapping[token_str]
-
-                # Apply bias if token ID is valid
-                if 0 <= token_id < scores.size(-1):
-                    scores[i, token_id] += bias_value
+        # Apply bias matrix as a broadcasted addition
+        if self.bias_matrix.shape[1] != scores.shape[1]:
+            # Pad if the bias matrix is smaller than the scores
+            self.bias_matrix = torch.nn.functional.pad(
+                self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
+            )
 
+        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
         return scores
 
-    def filter(self, indices: List[int]):
-        """Keep only the logit biases for the specified indices."""
+    def filter(self, indices):
         new_logit_biases = [self.logit_biases[i] for i in indices]
+        if not any(bias and len(bias) > 0 for bias in new_logit_biases):
+            return None
         return HeterogeneousLogitBiasProcessor(
             new_logit_biases, self.tokenizer, self.device
         )
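
For reference, below is a minimal standalone sketch (not part of this diff) of how the rewritten LogitBiasProcessor could be exercised in isolation. The gpt2 checkpoint, the dummy scores tensor, and calling the processor directly instead of through the server's generation loop are assumptions made purely for illustration.

# Standalone usage sketch (assumed; LogitBiasProcessor is the class defined in the
# diff above and is assumed to be importable from this module).
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint for the demo
device = torch.device("cpu")

# Bias token ID 9707 strongly downward, using the string-keyed format from the docstring.
processor = LogitBiasProcessor({"9707": -100.0}, tokenizer, device)

input_ids = torch.tensor([[1, 2, 3]])     # dummy prompt token IDs
scores = torch.randn(1, len(tokenizer))   # one row of next-token logits

before = scores[0, 9707].item()
scores = processor(input_ids, scores)
after = scores[0, 9707].item()
print(f"logit for token 9707: {before:.2f} -> {after:.2f}")  # shifted down by ~100

Compared with the per-token Python loop it replaces, the precomputed bias tensor reduces the per-step cost to a single vectorized addition on the scores.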