Commit d18d9c3

keyboardAnt, jmamou, gauravj14, and gauravjain14 authored
Universal Speculative Decoding CandidateGenerator (#35029)
* move `TestAssistedCandidateGeneratorDifferentTokenizers` into a new testing file
* refactor
* NOTHING. add space to rerun github actions tests
* remove it...
* `UniversalSpeculativeDecodingGenerator`
* Use `UniversalSpeculativeDecodingGenerator` when `generation_config.do_sample=True`
* assistant tokenizes only the target's new suffix
* formatting
* fix code
* fix code
* formatting
* add `TestGenerateWithDifferentModels`
* `TestGenerateWithDifferentModels` parameterize on `do_sample`
* `AssistantVocabMapping` & `AssistantVocabMappingCache`
* formatting
* `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_logits`
* improve `_get_assistant_to_target_input_ids` & formatting
* renaming
* WIP: debugging `min_new_tokens`
* fix get_target_ids
* `UniversalSpeculativeDecodingGenerator`
* assistant tokenizes only the target's new suffix
* formatting
* fix code
* fix code
* formatting
* `TestGenerateWithDifferentModels` parameterize on `do_sample`
* `AssistantVocabMapping` & `AssistantVocabMappingCache`
* formatting
* `AssistantToTargetTranslator`: `get_target_input_ids` & `get_target_logits`
* improve `_get_assistant_to_target_input_ids` & formatting
* renaming
* WIP: debugging `min_new_tokens`
* fix get_target_ids
* fix device issue
* fix get_assistant_input_ids
* add `TestAssistedCandidateGeneratorDifferentTokenizers`
* formatting
* `AssistantVocabTranslatorCache` refactor & tests
* revert changes in `src/transformers/generation/logits_process.py`
* refactor `AssistedCandidateGenerator`
* refactor `AssistedCandidateGeneratorDifferentTokenizers`
* formatting
* refactor `UniversalSpeculativeDecodingGenerator`
* fix negative value for max_new_tokens
* fix generation length: target + attention_mask vs. assistant + attention_mask
* fix device
* fix negative max_new_tokens bug
* fix UAG
* minor
* formatting
* `AssistedCandidateGeneratorDifferentTokenizers` `lookbehind`s init
* resolve conflict & formatting
* rerun CI tests
* remove space...
* remove old code
* fix candidate_input_ids device
* minor
* formatting
* Fix prepare + apply (#7)
* fix prepare + apply
* move to cpu
* simplify suppress_tokens
* fix bugs and refactoring
* device move
* handle self.config.vocab_size > len(target_tokenizer.get_vocab())
* no need to normalize in candidate_generator
* address Nadav's comments + minor
* optimize device move + SuppressTokensLogitsProcessor
* AssistantToTargetTranslator, SuppressTokensLogitsProcessor and tokenizers mapping improvements
* padding size
* padding improvement
* fix and simplify get_target_logits
* renaming in get_target_logits
* minor
* add filter_value and suppress_tokens_id
* style + rename
* remove TODO
* restore original SelectTokensLogitsProcessor with modification
* fix style
* fix _update_past_and_masks and optimize code
* remove assistant_vocab_size arg
* fix attention_mask
* call _prepare_attention_mask also if not has_past_key_values
* handle attention mask for first generation
* comment
* restore test
* remove SelectTokensLogitsProcessor
* _update_past_and_masks implementation for USD
* Add unittests for Universal Assisted generation
* fix style
* update tests
* Remove unused import and fix `test_speculation_depth` test
* exclude special and reserved tokens from tokenizer for UAG
* mv `test_universal_assisted_generation.py` to `generation/test_candidate_generator.py`
* Remove unused imports and fix style using `make style` (#9)
* formatting
* Swap gated `meta-llama/llama-3.2` with `allenai/llama` (#10)
* Fix space sign disagreement (#12)
* default values for AssistantToTargetTranslator fields
* fix space sign
* minor
* fix test + style
* Default values for some fields of assistant to target translator (#11)
* default values for AssistantToTargetTranslator fields
* fix
* add support for empty logit_processors
* Update candidate_generator.py (#15): fix typo
* BUG fix in _prepare_assistant_input_ids (#14)
* fix _prepare_assistant_input_ids
* target_to_assistant_input_ids
* Update src/transformers/generation/candidate_generator.py (Co-authored-by: Nadav Timor <[email protected]>)
* typo (`target_to_assistant_input_ids`)
* formatting
* merge upstream/main
* Fix minor review comments (#16)
* Fix: `token_ids.to(torch.int64)` (#18)
* tok ids to `torch.int64` (reference: https://huggingface.co/docs/transformers.js/en/api/tokenizers)
* `LongTensor`
* fix dtype
* `assistant_input_ids.to(dtype=torch.long)`
* Remove unused import from test_candidate_generator.py
* Remove unused import from test_candidate_generator.py
* Remove `numpy` import
* resolve pr comments (#19)
* `AssistantToTargetTranslator` docstring
* (per gante's comment) `filter_value` and `suppress_tokens_id` to class constants
* update `AssistantToTargetTranslator` docstring
* (gante's comment) replace `match-case`
* formatting
* Fix Joao's comments (#21)
* remove threading
* fix logits_processor
* fix test device
* fix style (#23)
* Move atm (#24)
* move AssistantToTargetTranslator
* fixup
* fix logit_processor
* add atm_translator test
* refactor test
* remove threading from test
* add require_torch in tests
* move AssistantVocabTranslatorCache + add tests
* ruff fix

Co-authored-by: jmamou <[email protected]>
Co-authored-by: Gaurav <[email protected]>
Co-authored-by: Gaurav Jain <[email protected]>
Co-authored-by: gauravjain14 <[email protected]>
1 parent 082834d commit d18d9c3

File tree

3 files changed: +638 −46 lines changed

src/transformers/generation/candidate_generator.py (+291 −3)
@@ -14,6 +14,7 @@
 # limitations under the License.

 import copy
+import weakref
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 import numpy as np
@@ -27,7 +28,7 @@

 from ..cache_utils import DynamicCache
 from ..pytorch_utils import isin_mps_friendly
-from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
+from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor


 if TYPE_CHECKING:
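
An aside on the second import: `SuppressTokensLogitsProcessor` is what the new translator uses to ban assistant tokens that have no target-vocabulary equivalent. A minimal sketch of its effect, assuming the two-argument `(suppress_tokens, device)` form used in this diff:

    import torch
    from transformers.generation.logits_process import SuppressTokensLogitsProcessor

    # Suppress token id 2 out of a toy 5-token vocabulary.
    processor = SuppressTokensLogitsProcessor([2], "cpu")
    scores = torch.zeros(1, 5)
    masked = processor(torch.tensor([[0]]), scores)
    print(masked)  # column 2 is now -inf, so token 2 can never be sampled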
@@ -283,18 +284,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
         min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
         return min_new_tokens, max_new_tokens

-    def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
+    def _update_past_and_masks(
+        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
+    ) -> bool:
         """Update past key values and attention masks for subsequent generation rounds."""
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
         if has_past_key_values:
             new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
             self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
+                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens
             )
             self.assistant_kwargs = _prepare_attention_mask(
                 self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
             )
             self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
+
         return has_past_key_values

     def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
@@ -608,6 +612,290 @@ def _process_assistant_outputs(
         return new_target_ids


+class AssistantToTargetTranslator:
+    """
+    Translates token ids and logits between the assistant and target model vocabularies. This class is used to handle
+    vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
+    as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
+    (https://www.arxiv.org/abs/2502.05202).
+    It maintains mappings between the two vocabularies and handles token/logit conversion.
+
+    Args:
+        target_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the target (main) model.
+        assistant_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the assistant model.
+        target_vocab_size (`int`):
+            The size of the target model's vocabulary. Passed explicitly because it can differ from
+            `len(target_tokenizer.get_vocab())`.
+        assistant_model_device (`str`, defaults to `"cpu"`):
+            The device where the assistant model is located. Used for placing tensors.
+    """
+
+    FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
+    SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.
+
+    def __init__(
+        self,
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        target_vocab_size: int,  # required since target_vocab_size can differ from len(target_tokenizer.get_vocab())
+        assistant_model_device: str = "cpu",
+    ):
+        self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
+        self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
+        self._assistant_model_device: str = assistant_model_device
+        self.target_vocab_size: int = target_vocab_size
+        self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
+            self._get_assistant_to_target_input_ids()
+        )
+        self._suppress_input_ids: torch.Tensor = self._get_suppress_input_ids()
+        self.logits_processors: Optional[LogitsProcessorList] = None
+        if len(self._suppress_input_ids) > 0:
+            # len(self._suppress_input_ids) is 0 if the assistant vocab is a subset of the target vocab
+            self.logits_processors = LogitsProcessorList(
+                [SuppressTokensLogitsProcessor(self._suppress_input_ids, self._assistant_model_device)]
+            )
+
+    def _get_assistant_to_target_input_ids(self):
+        target_vocab = self._target_tokenizer.get_vocab()
+        assistant_vocab = self._assistant_tokenizer.get_vocab()
+
+        space_str = " "
+        target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
+        if len(target_space_ids) > 0:
+            target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]
+
+            assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
+            if len(assistant_space_ids) > 0:
+                assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]
+
+                if target_space_sign != assistant_space_sign:
+                    # If the assistant tokenizer has a different space sign than the target tokenizer,
+                    # we need to replace the assistant space sign with the target space sign in the assistant_vocab.
+                    assistant_vocab = {
+                        (
+                            tok.replace(assistant_space_sign, target_space_sign, 1)
+                            if tok.startswith(assistant_space_sign)
+                            else tok
+                        ): idx
+                        for tok, idx in assistant_vocab.items()
+                    }
+
+        max_assistant_index = max(assistant_vocab.values())
+        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
+        target_to_assistant_input_ids: Dict[int, int] = {}
+        for tok, assistant_id in assistant_vocab.items():
+            target_id = target_vocab.get(tok)
+            if target_id is not None:
+                assistant_to_target_input_ids[assistant_id] = target_id
+                target_to_assistant_input_ids[target_id] = assistant_id
+        return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
+
+    def _get_suppress_input_ids(self) -> torch.Tensor:
+        """
+        Get the input ids that are in the assistant vocab but not in the target vocab.
+        """
+        return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]
+
+    def get_target_ids(
+        self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
+    ) -> torch.LongTensor:
+        """
+        Return the target candidate ids that correspond to the assistant candidate ids.
+        Note that we already have the target ids for the prompt and only need to find the target ids for the new tokens.
+        Moreover, assistant ids of the original prompt do not necessarily appear in _assistant_to_target_input_ids.
+        """
+
+        num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
+        if num_new_tokens == 0:
+            return target_input_ids
+        else:
+            transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]]
+            return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)
+
+    def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Return the target logits that correspond to the assistant logits.
+        """
+
+        target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
+        target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
+        # Mask for valid indices
+        assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
+        # Exclude invalid indices
+        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
+        valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
+
+        target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
+
+        return target_logits
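
To make the vocabulary mapping concrete, here is a small self-contained sketch with toy vocabularies (hypothetical stand-ins for real tokenizer vocabs, not part of the diff). It reproduces the two core moves: building the assistant-to-target id table with unmapped entries marked for suppression, and scattering assistant logits into a target-sized tensor as `get_target_logits` does:

    import torch

    SUPPRESS_TOKEN_ID = -1
    target_vocab = {"hello": 0, "world": 1, "!": 2, "?": 3}
    assistant_vocab = {"hello": 0, "?": 1, "zzz": 2}  # "zzz" has no target equivalent

    a2t = torch.full((max(assistant_vocab.values()) + 1,), SUPPRESS_TOKEN_ID, dtype=torch.long)
    for tok, aid in assistant_vocab.items():
        tid = target_vocab.get(tok)
        if tid is not None:
            a2t[aid] = tid
    print(a2t)  # tensor([ 0,  3, -1]): assistant "?" maps to target id 3, "zzz" is suppressed

    # Scatter assistant logits into a target-sized tensor, filling unmapped slots with -inf:
    assistant_logits = torch.tensor([[0.9, 0.5, 0.1]])
    target_logits = torch.full((1, len(target_vocab)), -float("inf"))
    mask = a2t != SUPPRESS_TOKEN_ID
    target_logits[..., a2t[mask]] = assistant_logits[..., mask]
    print(target_logits)  # tensor([[0.9, -inf, -inf, 0.5]])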
+class AssistantVocabTranslatorCache:
+    """
+    Cache for `AssistantToTargetTranslator` instances. The instances are computed at
+    pre-processing time, and this cache allows us to avoid recomputing them.
+    """
+
+    _cache = weakref.WeakKeyDictionary()
+
+    @classmethod
+    def get_translator(
+        cls,
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        target_vocab_size: int,
+        assistant_model_device: str = "cpu",
+    ) -> AssistantToTargetTranslator:
+        assistant_dict = cls._cache.get(target_tokenizer)
+        if assistant_dict is None:
+            assistant_dict = weakref.WeakKeyDictionary()
+            cls._cache[target_tokenizer] = assistant_dict
+
+        mapping = assistant_dict.get(assistant_tokenizer)
+        if mapping is None:
+            mapping = AssistantToTargetTranslator(
+                target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device
+            )
+            assistant_dict[assistant_tokenizer] = mapping
+
+        return mapping
+
+    @classmethod
+    def cleanup(cls):
+        """
+        Clean up dead references in the cache.
+        This removes entries where either the target_tokenizer or assistant_tokenizer
+        has been garbage collected.
+        """
+        # Remove entries from the outer cache where the target_tokenizer is no longer alive
+        dead_keys = [key for key in cls._cache if key is None]
+        for key in dead_keys:
+            del cls._cache[key]
+
+        # For each assistant_dict, remove entries where assistant_tokenizer is no longer alive
+        for assistant_dict in cls._cache.values():
+            dead_keys = [key for key in assistant_dict if key is None]
+            for key in dead_keys:
+                del assistant_dict[key]
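
Usage is one call per tokenizer pair, and repeated calls return the same cached instance. A hedged sketch, assuming `tgt_tok` and `ast_tok` are live tokenizer objects and `vocab_size` comes from the target model's config (it may exceed `len(tgt_tok.get_vocab())`):

    translator = AssistantVocabTranslatorCache.get_translator(
        tgt_tok, ast_tok, target_vocab_size=vocab_size, assistant_model_device="cuda"
    )
    # The same pair hits the cache and returns the identical object:
    assert translator is AssistantVocabTranslatorCache.get_translator(
        tgt_tok, ast_tok, target_vocab_size=vocab_size, assistant_model_device="cuda"
    )

Because both cache levels are `weakref.WeakKeyDictionary`s, entries vanish on their own once a tokenizer is garbage collected; `cleanup()` is a defensive sweep on top of that.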
+class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
+    """
+    `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with
+    different tokenizers for the assistant and main models. This class generates candidates through the use of a
+    smaller model.
+    """
+
+    def __init__(
+        self,
+        input_ids: torch.LongTensor,
+        assistant_model: "PreTrainedModel",
+        target_tokenizer: "PreTrainedTokenizerBase",
+        assistant_tokenizer: "PreTrainedTokenizerBase",
+        generation_config: "GenerationConfig",
+        model_kwargs: Dict,
+        atm_translator: AssistantToTargetTranslator,
+        inputs_tensor: Optional[torch.Tensor] = None,
+        logits_processor: Optional["LogitsProcessorList"] = None,
+    ):
+        # Initialize the translator before the parent class
+        self._atm_translator = atm_translator
+        super().__init__(
+            input_ids,
+            assistant_model,
+            target_tokenizer,
+            assistant_tokenizer,
+            generation_config,
+            model_kwargs,
+            inputs_tensor,
+            logits_processor,
+        )
+        # Track sequence lengths and previous assistant IDs
+        self._target_seq_len_with_candidates: int = 0
+        self._prev_assistant_ids: Optional[torch.LongTensor] = None
+
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
+        """
+        Simplified version of get_candidates that uses the translator cache for token conversion.
+        """
+        target_input_ids = input_ids.to(self.assistant_model.device)
+        assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
+        min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)
+
+        if max_new_tokens == 0:
+            return input_ids, None
+
+        self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
+        generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
+
+        # Ensure scores are returned
+        generation_args["generation_config"].output_scores = True
+        generation_args["generation_config"].return_dict_in_generate = True
+
+        # Generate and process outputs using the translator
+        if self._atm_translator.logits_processors is not None:
+            generation_args["logits_processor"] = self._atm_translator.logits_processors
+        self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)
+
+        # Use the translator to convert tokens and logits
+        target_candidate_ids = self._atm_translator.get_target_ids(
+            assistant_input_ids, target_input_ids, self._prev_assistant_ids
+        )
+        self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
+        target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)
+
+        return target_candidate_ids, target_candidate_logits
+
+    def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
+        if self._prev_assistant_ids is None:
+            # Prepare the attention mask for the first generation.
+            # For subsequent generations, the attention mask is updated in super()._update_past_and_masks.
+            self.assistant_kwargs = _prepare_attention_mask(
+                self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
+            )
+        return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
+
+    def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
+        """
+        Simplified token conversion that only processes new tokens.
+        """
+        # Calculate the number of new tokens since the last call
+        target_seq_len = target_input_ids.shape[-1]
+        if self._target_seq_len_with_candidates == 0:
+            new_token_count = target_seq_len
+        else:
+            new_token_count = 1
+        target_new_ids = target_input_ids[:, -new_token_count:]
+
+        # Convert the new tokens
+        assistant_new_ids = None
+        if self._target_seq_len_with_candidates > 0:
+            # we have only one new token and we can directly convert it
+            assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
+        if assistant_new_ids is None:
+            target_new_text = self.target_tokenizer.batch_decode(
+                target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+            )
+            assistant_new_ids = self.assistant_tokenizer(
+                target_new_text, add_special_tokens=False, return_tensors="pt"
+            )["input_ids"].to(self.assistant_model.device)
+        else:
+            assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)
+
+        # Update or initialize the assistant IDs
+        if self._prev_assistant_ids is None:
+            assistant_input_ids = assistant_new_ids
+        else:
+            tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
+            # If the number of tokens to remove is greater than zero, truncate the previous assistant IDs
+            if tokens_to_remove > 0:
+                self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
+            assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
+        assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
+
+        return assistant_input_ids, len(assistant_new_ids[0])
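
From the user's side, none of these classes are constructed directly: passing both tokenizers to `generate` selects assisted generation across different vocabularies (with this PR, the universal speculative decoding path when sampling is enabled). A hedged end-to-end sketch; the model names are placeholders for any causal LM pair with mismatched tokenizers:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    target_tok = AutoTokenizer.from_pretrained("big-org/target-model")          # placeholder id
    assistant_tok = AutoTokenizer.from_pretrained("small-org/assistant-model")  # placeholder id
    target = AutoModelForCausalLM.from_pretrained("big-org/target-model")
    assistant = AutoModelForCausalLM.from_pretrained("small-org/assistant-model")

    inputs = target_tok("The quick brown fox", return_tensors="pt")
    out = target.generate(
        **inputs,
        assistant_model=assistant,
        tokenizer=target_tok,               # target-side tokenizer
        assistant_tokenizer=assistant_tok,  # different assistant tokenizer
        do_sample=True,
        max_new_tokens=32,
    )
    print(target_tok.batch_decode(out, skip_special_tokens=True)[0])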
 class PromptLookupCandidateGenerator(CandidateGenerator):
     """
     `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
