 import gc
+import logging
 import threading
 import unittest
 import weakref
 from unittest.mock import MagicMock
 
 import numpy as np
 import torch
 
 from transformers.generation.candidate_generator import (
     AssistantToTargetTranslator,
     AssistantVocabTranslatorCache,
     AssistedCandidateGeneratorDifferentTokenizers,
+    UniversalSpeculativeDecodingGenerator,
 )
 
+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+
 
 class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
     def test_no_intersection(self):
@@ -256,3 +262,91 @@ def get_translator():
         # All translators should be the same instance
         for translator in translators:
             self.assertIs(translators[0], translator, "All translators should be identical across threads")
+
+
+class TestUniversalSpeculativeDecoding(unittest.TestCase):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    @classmethod
+    def setUpClass(cls):
+        cls.assistant_model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-gpt2").to(cls.device)
+        cls.main_tokenizer = AutoTokenizer.from_pretrained(
+            "meta-llama/Llama-3.2-1B-Instruct")
+        cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-gpt2")
+        cls.generation_config = GenerationConfig()
+
+        # Ensure required tokens exist
+        if cls.main_tokenizer.pad_token_id is None:
+            cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
+        if cls.main_tokenizer.bos_token_id is None:
+            cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id
+
+    def setUp(self):
+        self.input_ids = torch.tensor([[1, 2, 3]]).to(self.device)
+        self.model_kwargs = {
+            "attention_mask": torch.ones_like(self.input_ids).to(self.device),
+        }
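+        # Build a fresh generator for every test so assistant-side state cannot leak between tests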
+        self.generator = UniversalSpeculativeDecodingGenerator(
+            input_ids=self.input_ids,
+            assistant_model=self.assistant_model,
+            target_tokenizer=self.main_tokenizer,
+            assistant_tokenizer=self.assistant_tokenizer,
+            generation_config=self.generation_config,
+            model_kwargs=self.model_kwargs,
+            target_vocab_size=self.main_tokenizer.vocab_size,
+        )
+
+    def test_basic_generation(self):
+        """Test that basic speculative decoding works."""
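+        # get_candidates is expected to return candidate ids (and scores) expressed in the target vocabulary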
+        input_text = "The quick brown fox"
+        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt").to(self.device)
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+
+        self.assertIsNotNone(candidates)
+        self.assertIsNotNone(scores)
+        self.assertTrue(torch.is_tensor(candidates))
+        self.assertTrue(torch.is_tensor(scores))
+
+    def test_mismatched_vocabularies(self):
+        """Test handling of mismatched vocabularies between models."""
+        # Find a token that exists in the main tokenizer's vocabulary
+        # but is absent from the assistant tokenizer's.
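+        # Special and "reserved_" placeholder tokens are skipped, since they may be unused vocab slots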
+        missing_token = next(
+            token for token in self.main_tokenizer.get_vocab()
+            if token not in self.assistant_tokenizer.get_vocab()
+            and token not in self.main_tokenizer.all_special_tokens
+            and "reserved_" not in token
+        )
+        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]]).to(self.device)
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+        self.assertIsNotNone(candidates)
+
+    def test_speculation_depth(self):
+        """Test generation at different speculation depths."""
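+        # num_assistant_tokens caps how many draft tokens the assistant proposes per step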
+        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt").to(self.device)
+        self.generator.input_ids = input_ids
+
+        for depth in [1, 8, 17]:
+            self.generator.num_assistant_tokens = depth
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertLessEqual(
+                candidates.shape[1] - input_ids.shape[1], depth
+            )
+
+    def test_device_consistency(self):
+        """Test that candidates are returned on the same device as the inputs."""
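+        # Only meaningful when CUDA is available; on CPU-only runs the body is skipped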
+        if torch.cuda.is_available():
+            input_ids = torch.tensor([[1, 2, 3]]).to(
+                self.generator.assistant_model.device)
+            self.generator.input_ids = input_ids
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertEqual(candidates.device, input_ids.device)