|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 | 16 | import copy
|
| 17 | +import weakref |
17 | 18 | from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
18 | 19 |
|
19 | 20 | import numpy as np
|
|
27 | 28 |
|
28 | 29 | from ..cache_utils import DynamicCache
|
29 | 30 | from ..pytorch_utils import isin_mps_friendly
|
30 | | -from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor |
| 31 | +from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor |
31 | 32 |
|
32 | 33 |
|
33 | 34 | if TYPE_CHECKING:
|
@@ -283,18 +284,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
|
283 | 284 | min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
284 | 285 | return min_new_tokens, max_new_tokens
|
285 | 286 |
|
286 | | - def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool: |
| 287 | + def _update_past_and_masks( |
| 288 | + self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1 |
| 289 | + ) -> bool: |
287 | 290 | """Update past key values and attention masks for subsequent generation rounds."""
|
288 | 291 | has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
289 | 292 | if has_past_key_values:
|
290 | 293 | new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
|
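| + # `num_added_tokens` counts the assistant ids appended since the last round (with |
| + # mismatched tokenizers a single target token may re-tokenize into several assistant |
| + # tokens), so the cache is cropped back far enough that every newly appended id is |
| + # recomputed in the next forward pass. |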
291 | 294 | self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
|
292 | | - self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 |
| 295 | + self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens |
293 | 296 | )
|
294 | 297 | self.assistant_kwargs = _prepare_attention_mask(
|
295 | 298 | self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
|
296 | 299 | )
|
297 | 300 | self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
|
| 301 | + |
298 | 302 | return has_past_key_values
|
299 | 303 |
|
300 | 304 | def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
|
@@ -608,6 +612,290 @@ def _process_assistant_outputs(
|
608 | 612 | return new_target_ids
|
609 | 613 |
|
610 | 614 |
|
| 615 | +class AssistantToTargetTranslator: |
| 616 | + """ |
| 617 | + Translates token ids and logits between assistant and target model vocabularies. This class is used to handle |
| 618 | + vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding, |
| 619 | + as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies" |
| 620 | + (https://www.arxiv.org/abs/2502.05202). |
| 621 | + It maintains mappings between the two vocabularies and handles token/logit conversion. |
| 622 | +
|
| 623 | + Args: |
| 624 | + target_tokenizer (`PreTrainedTokenizerBase`): |
| 625 | + The tokenizer used by the target (main) model. |
| 626 | + assistant_tokenizer (`PreTrainedTokenizerBase`): |
| 627 | + The tokenizer used by the assistant model. |
 | 628 | + target_vocab_size (`int`): |
 | 629 | + The size of the target model's vocabulary, which may differ from `len(target_tokenizer.get_vocab())`. |
 | 630 | + assistant_model_device (`str`, *optional*, defaults to `"cpu"`): |
 | 631 | + The device where the assistant model is located. Used for placing tensors. |
| 632 | + """ |
| 633 | + |
| 634 | + FILTER_VALUE: float = -float("Inf") # The value used to filter out unmapped tokens in the logits. |
| 635 | + SUPPRESS_TOKEN_ID: int = -1 # The ID used to mark suppressed tokens in the mapping. |
| 636 | + |
| 637 | + def __init__( |
| 638 | + self, |
| 639 | + target_tokenizer: "PreTrainedTokenizerBase", |
| 640 | + assistant_tokenizer: "PreTrainedTokenizerBase", |
| 641 | + target_vocab_size: int, # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab() |
| 642 | + assistant_model_device: str = "cpu", |
| 643 | + ): |
| 644 | + self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer |
| 645 | + self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer |
| 646 | + self._assistant_model_device: str = assistant_model_device |
| 647 | + self.target_vocab_size: int = target_vocab_size |
| 648 | + self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = ( |
| 649 | + self._get_assistant_to_target_input_ids() |
| 650 | + ) |
 | 651 | + self._suppress_input_ids: torch.Tensor = self._get_suppress_input_ids() |
| 652 | + self.logits_processors: Optional[LogitsProcessorList] = None |
| 653 | + if len(self._suppress_input_ids) > 0: |
| 654 | + # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab |
| 655 | + self.logits_processors = LogitsProcessorList( |
 | 656 | + [SuppressTokensLogitsProcessor(self._suppress_input_ids, self._assistant_model_device)] |
| 657 | + ) |
| 658 | + |
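 | + # Illustrative example (hypothetical vocabularies): given assistant_vocab = {"a": 0, "b": 1, "X": 2} |
 | + # and target_vocab = {"a": 5, "b": 9}, `_assistant_to_target_input_ids` is tensor([5, 9, -1]) |
 | + # ("X" has no target counterpart, so it keeps SUPPRESS_TOKEN_ID and is suppressed during |
 | + # assistant generation), and `target_to_assistant_input_ids` is {5: 0, 9: 1}. |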
| 659 | + def _get_assistant_to_target_input_ids(self): |
| 660 | + target_vocab = self._target_tokenizer.get_vocab() |
| 661 | + assistant_vocab = self._assistant_tokenizer.get_vocab() |
| 662 | + |
| 663 | + space_str = " " |
| 664 | + target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"] |
| 665 | + if len(target_space_ids) > 0: |
| 666 | + target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0] |
| 667 | + |
| 668 | + assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"] |
| 669 | + if len(assistant_space_ids) > 0: |
| 670 | + assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0] |
| 671 | + |
| 672 | + if target_space_sign != assistant_space_sign: |
| 673 | + # If the assistant tokenizer has a different space sign than the target tokenizer, |
| 674 | + # we need to replace the assistant space sign with the target space sign in the assistant_vocab. |
| 675 | + assistant_vocab = { |
| 676 | + ( |
| 677 | + tok.replace(assistant_space_sign, target_space_sign, 1) |
| 678 | + if tok.startswith(assistant_space_sign) |
| 679 | + else tok |
| 680 | + ): idx |
| 681 | + for tok, idx in assistant_vocab.items() |
| 682 | + } |
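 | + # This normalization covers tokenizer families that mark a leading space differently, |
 | + # e.g. GPT-2-style byte-level BPE uses "Ġhello" while SentencePiece uses "▁hello"; |
 | + # without rewriting the prefix, identical surface tokens would never match across vocabs. |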
| 683 | + |
| 684 | + max_assistant_index = max(assistant_vocab.values()) |
 | 685 | + assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=torch.long) |
| 686 | + target_to_assistant_input_ids: Dict[int, int] = {} |
| 687 | + for tok, assistant_id in assistant_vocab.items(): |
| 688 | + target_id = target_vocab.get(tok) |
| 689 | + if target_id is not None: |
| 690 | + assistant_to_target_input_ids[assistant_id] = target_id |
| 691 | + target_to_assistant_input_ids[target_id] = assistant_id |
| 692 | + return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids |
| 693 | + |
 | 694 | + def _get_suppress_input_ids(self) -> torch.Tensor: |
| 695 | + """ |
| 696 | + Get the input ids that are in the assistant vocab but not in the target vocab. |
| 697 | + """ |
| 698 | + return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0] |
| 699 | + |
| 700 | + def get_target_ids( |
 | 701 | + self, assistant_input_ids: torch.LongTensor, target_input_ids: torch.LongTensor, assistant_candidate_ids: torch.LongTensor |
| 702 | + ) -> torch.LongTensor: |
| 703 | + """ |
| 704 | + Return the target candidate ids that correspond to the assistant candidate ids. |
 | 705 | + Note that we already have the target ids for the prompt, so we only need to find the target ids for the new tokens. |
 | 706 | + Moreover, the assistant ids of the original prompt do not necessarily appear in `_assistant_to_target_input_ids`. |
| 707 | + """ |
| 708 | + |
| 709 | + num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] |
| 710 | + if num_new_tokens == 0: |
| 711 | + return target_input_ids |
| 712 | + else: |
| 713 | + transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -num_new_tokens:]] |
| 714 | + return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1) |
| 715 | + |
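 | + # Example with hypothetical lengths: if assistant_input_ids has 7 tokens and |
 | + # assistant_candidate_ids has 9, only the last 2 candidate ids are looked up in |
 | + # `_assistant_to_target_input_ids` and appended to target_input_ids; the prompt keeps |
 | + # the target ids it already has. |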
| 716 | + def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor: |
| 717 | + """ |
| 718 | + Return the target logits that correspond to the assistant logits. |
| 719 | + """ |
| 720 | + |
| 721 | + target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size) |
 | 722 | + target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE, device=self._assistant_model_device) |
| 723 | + # Mask for valid indices |
| 724 | + assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID |
| 725 | + # Exclude invalid indices |
| 726 | + target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask] |
| 727 | + valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]] |
| 728 | + |
| 729 | + target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask] |
| 730 | + |
| 731 | + return target_logits |
| 732 | + |
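 | + # In effect the assistant logits are scattered into a target-vocab-sized tensor: each |
 | + # assistant token with a target counterpart contributes its logit at the target index, |
 | + # while unmapped target positions stay at FILTER_VALUE (-inf) and can never be selected. |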
| 733 | + |
| 734 | +class AssistantVocabTranslatorCache: |
| 735 | + """ |
| 736 | + Cache for `AssistantToTargetTranslator` instances. The instances are computed at |
| 737 | + pre-processing time, and this cache allows us to avoid recomputing them. |
| 738 | + """ |
| 739 | + |
| 740 | + _cache = weakref.WeakKeyDictionary() |
| 741 | + |
| 742 | + @classmethod |
| 743 | + def get_translator( |
| 744 | + cls, |
| 745 | + target_tokenizer: "PreTrainedTokenizerBase", |
| 746 | + assistant_tokenizer: "PreTrainedTokenizerBase", |
| 747 | + target_vocab_size: int, |
| 748 | + assistant_model_device: str = "cpu", |
| 749 | + ) -> AssistantToTargetTranslator: |
| 750 | + assistant_dict = cls._cache.get(target_tokenizer) |
| 751 | + if assistant_dict is None: |
| 752 | + assistant_dict = weakref.WeakKeyDictionary() |
| 753 | + cls._cache[target_tokenizer] = assistant_dict |
| 754 | + |
| 755 | + mapping = assistant_dict.get(assistant_tokenizer) |
| 756 | + if mapping is None: |
| 757 | + mapping = AssistantToTargetTranslator( |
| 758 | + target_tokenizer, assistant_tokenizer, target_vocab_size, assistant_model_device |
| 759 | + ) |
| 760 | + assistant_dict[assistant_tokenizer] = mapping |
| 761 | + |
| 762 | + return mapping |
| 763 | + |
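 | + # Sketch of a typical call site (names like `target_model` are illustrative, not part of this API): |
 | + # translator = AssistantVocabTranslatorCache.get_translator( |
 | + #     target_tokenizer, |
 | + #     assistant_tokenizer, |
 | + #     target_vocab_size=target_model.config.vocab_size, |
 | + #     assistant_model_device=str(assistant_model.device), |
 | + # ) |
 | + # Repeated calls with the same tokenizer pair return the cached instance, so the |
 | + # vocabulary mapping is built only once per (target, assistant) pair. |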
| 764 | + @classmethod |
| 765 | + def cleanup(cls): |
| 766 | + """ |
 | 767 | + Clean up dead references in the cache. Note that `WeakKeyDictionary` already drops |
 | 768 | + entries whose keys have been garbage collected, so this is a defensive sweep over |
 | 769 | + both the outer cache and the per-target assistant dictionaries. |
| 770 | + """ |
| 771 | + # Remove entries from the outer cache where the target_tokenizer is no longer alive |
| 772 | + dead_keys = [key for key in cls._cache if key is None] |
| 773 | + for key in dead_keys: |
| 774 | + del cls._cache[key] |
| 775 | + |
| 776 | + # For each assistant_dict, remove entries where assistant_tokenizer is no longer alive |
| 777 | + for assistant_dict in cls._cache.values(): |
| 778 | + dead_keys = [key for key in assistant_dict if key is None] |
| 779 | + for key in dead_keys: |
| 780 | + del assistant_dict[key] |
| 781 | + |
| 782 | + |
| 783 | +class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers): |
| 784 | + """ |
| 785 | + `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers |
| 786 | + for the assistant and main models. This class generates candidates through the use of a smaller model. |
| 787 | + """ |
| 788 | + |
| 789 | + def __init__( |
| 790 | + self, |
| 791 | + input_ids: torch.LongTensor, |
| 792 | + assistant_model: "PreTrainedModel", |
| 793 | + target_tokenizer: "PreTrainedTokenizerBase", |
| 794 | + assistant_tokenizer: "PreTrainedTokenizerBase", |
| 795 | + generation_config: "GenerationConfig", |
| 796 | + model_kwargs: Dict, |
| 797 | + atm_translator: AssistantToTargetTranslator, |
| 798 | + inputs_tensor: Optional[torch.Tensor] = None, |
 | 799 | + logits_processor: Optional["LogitsProcessorList"] = None, |
| 800 | + ): |
| 801 | + # Initialize translator before parent class |
| 802 | + self._atm_translator = atm_translator |
| 803 | + super().__init__( |
| 804 | + input_ids, |
| 805 | + assistant_model, |
| 806 | + target_tokenizer, |
| 807 | + assistant_tokenizer, |
| 808 | + generation_config, |
| 809 | + model_kwargs, |
| 810 | + inputs_tensor, |
| 811 | + logits_processor, |
| 812 | + ) |
| 813 | + # Track sequence lengths and previous assistant IDs |
| 814 | + self._target_seq_len_with_candidates: int = 0 |
| 815 | + self._prev_assistant_ids: Optional[torch.LongTensor] = None |
| 816 | + |
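 | + # `_target_seq_len_with_candidates` records the target-side length including the last |
 | + # round's unverified candidate tokens; `_prepare_assistant_input_ids` compares it to the |
 | + # verified input length to work out how many candidates were rejected. |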
| 817 | + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: |
| 818 | + """ |
 | 819 | + Simplified version of `get_candidates` that uses the vocabulary translator for token and logit conversion. |
| 820 | + """ |
| 821 | + target_input_ids = input_ids.to(self.assistant_model.device) |
| 822 | + assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids) |
| 823 | + min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids) |
| 824 | + |
| 825 | + if max_new_tokens == 0: |
| 826 | + return input_ids, None |
| 827 | + |
| 828 | + self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) |
| 829 | + generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) |
| 830 | + |
| 831 | + # Ensure scores are returned |
| 832 | + generation_args["generation_config"].output_scores = True |
| 833 | + generation_args["generation_config"].return_dict_in_generate = True |
| 834 | + |
| 835 | + # Generate and process outputs using translator |
| 836 | + if self._atm_translator.logits_processors is not None: |
| 837 | + generation_args["logits_processor"] = self._atm_translator.logits_processors |
| 838 | + self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) |
| 839 | + |
| 840 | + # Use translator to convert tokens and logits |
| 841 | + target_candidate_ids = self._atm_translator.get_target_ids( |
| 842 | + assistant_input_ids, target_input_ids, self._prev_assistant_ids |
| 843 | + ) |
| 844 | + self._target_seq_len_with_candidates = target_candidate_ids.shape[-1] |
| 845 | + target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits) |
| 846 | + |
| 847 | + return target_candidate_ids, target_candidate_logits |
| 848 | + |
| 849 | + def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool: |
| 850 | + if self._prev_assistant_ids is None: |
| 851 | + # Prepare attention mask for the first generation. |
 | 852 | + # For subsequent generations, the attention mask is updated in `super()._update_past_and_masks()`. |
| 853 | + self.assistant_kwargs = _prepare_attention_mask( |
| 854 | + self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder |
| 855 | + ) |
| 856 | + return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) |
| 857 | + |
 | 858 | + def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]: |
| 859 | + """ |
| 860 | + Simplified token conversion that only processes new tokens. |
| 861 | + """ |
| 862 | + # Calculate new tokens since last call |
| 863 | + target_seq_len = target_input_ids.shape[-1] |
| 864 | + if self._target_seq_len_with_candidates == 0: |
| 865 | + new_token_count = target_seq_len |
| 866 | + else: |
| 867 | + new_token_count = 1 |
| 868 | + target_new_ids = target_input_ids[:, -new_token_count:] |
| 869 | + |
| 870 | + # Convert the new tokens |
| 871 | + assistant_new_ids = None |
| 872 | + if self._target_seq_len_with_candidates > 0: |
 | 873 | + # only one new token was added, so it can be converted directly via the mapping |
| 874 | + assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item()) |
| 875 | + if assistant_new_ids is None: |
| 876 | + target_new_text = self.target_tokenizer.batch_decode( |
| 877 | + target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True |
| 878 | + ) |
| 879 | + assistant_new_ids = self.assistant_tokenizer( |
| 880 | + target_new_text, add_special_tokens=False, return_tensors="pt" |
| 881 | + )["input_ids"].to(self.assistant_model.device) |
| 882 | + else: |
| 883 | + assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device) |
| 884 | + |
| 885 | + # Update or initialize assistant IDs |
| 886 | + if self._prev_assistant_ids is None: |
| 887 | + assistant_input_ids = assistant_new_ids |
| 888 | + else: |
| 889 | + tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len |
 | 890 | + # If candidates from the previous round were rejected, crop their assistant ids |
| 891 | + if tokens_to_remove > 0: |
| 892 | + self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove] |
| 893 | + assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1) |
| 894 | + assistant_input_ids = assistant_input_ids.to(dtype=torch.long) |
| 895 | + |
| 896 | + return assistant_input_ids, len(assistant_new_ids[0]) |
| 897 | + |
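 | + # Worked example (hypothetical lengths): if the previous round proposed candidates up to |
 | + # target length 10 and the verified input now has target_seq_len == 8 (7 accepted tokens |
 | + # plus 1 newly sampled one), then tokens_to_remove == 10 + 1 - 8 == 3 and the assistant |
 | + # ids of the 3 rejected candidates are cropped before the new token's ids are appended. |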
| 898 | + |
611 | 899 | class PromptLookupCandidateGenerator(CandidateGenerator):
|
612 | 900 | """
|
613 | 901 | `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
|