Commit 7088978

Merge pull request huggingface#8 from keyboardAnt/unit_tests_usd

Add unittests for Universal Assisted generation

2 parents e047adf + 701edbb
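
Universal assisted generation lets a small draft model propose candidate tokens for a larger target model even when the two use different tokenizers. The tests added here drive the low-level UniversalSpeculativeDecodingGenerator directly; for orientation, the following is a minimal end-to-end sketch, assuming the tokenizer/assistant_tokenizer keywords that recent transformers releases expose on generate() (that public entry point is not part of this diff, and the checkpoints are just the ones the tests below use):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Target and draft checkpoints as used in the tests below.
    # (The meta-llama checkpoint is gated and may require Hub authentication.)
    target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    target = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    draft_tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    draft = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

    inputs = target_tok("The quick brown fox", return_tensors="pt")
    # Passing both tokenizers is what selects universal assisted generation:
    # draft candidates are re-encoded from the draft vocabulary into the
    # target vocabulary before the target model verifies them.
    out = target.generate(
        **inputs,
        assistant_model=draft,
        tokenizer=target_tok,
        assistant_tokenizer=draft_tok,
        max_new_tokens=20,
    )
    print(target_tok.decode(out[0], skip_special_tokens=True))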

File tree

1 file changed: +88 −0 lines changed

tests/generation/test_candidate_generator.py

Lines changed: 88 additions & 0 deletions
@@ -1,18 +1,22 @@
 import gc
+import logging
 import threading
 import unittest
 import weakref
 from unittest.mock import MagicMock

 import numpy as np
 import torch

 from transformers.generation.candidate_generator import (
     AssistantToTargetTranslator,
     AssistantVocabTranslatorCache,
     AssistedCandidateGeneratorDifferentTokenizers,
+    UniversalSpeculativeDecodingGenerator,
 )

+from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+

 class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
     def test_no_intersection(self):
@@ -256,3 +260,87 @@ def get_translator():
         # All translators should be the same instance
         for translator in translators:
             self.assertIs(translators[0], translator, "All translators should be identical across threads")
+
+
+class TestUniversalSpeculativeDecoding(unittest.TestCase):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    @classmethod
+    def setUpClass(cls):
+        cls.assistant_model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-gpt2").to(cls.device)
+        cls.main_tokenizer = AutoTokenizer.from_pretrained(
+            "meta-llama/Llama-3.2-1B-Instruct")
+        cls.assistant_tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-gpt2")
+        cls.generation_config = GenerationConfig()
+
+        # Ensure required tokens exist
+        if cls.main_tokenizer.pad_token_id is None:
+            cls.main_tokenizer.pad_token_id = cls.main_tokenizer.eos_token_id
+        if cls.main_tokenizer.bos_token_id is None:
+            cls.main_tokenizer.bos_token_id = cls.main_tokenizer.eos_token_id
+
+    def setUp(self):
+        self.input_ids = torch.tensor([[1, 2, 3]]).to(self.device)
+        self.model_kwargs = {
+            "attention_mask": torch.ones_like(self.input_ids).to(self.device),
+        }
+        self.generator = UniversalSpeculativeDecodingGenerator(
+            input_ids=self.input_ids,
+            assistant_model=self.assistant_model,
+            target_tokenizer=self.main_tokenizer,
+            assistant_tokenizer=self.assistant_tokenizer,
+            generation_config=self.generation_config,
+            model_kwargs=self.model_kwargs,
+            target_vocab_size=self.main_tokenizer.vocab_size,
+        )
+
+    def test_basic_generation(self):
+        """Test that basic speculative decoding works."""
+        input_text = "The quick brown fox"
+        input_ids = self.main_tokenizer.encode(input_text, return_tensors="pt")
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+
+        self.assertIsNotNone(candidates)
+        self.assertIsNotNone(scores)
+        self.assertTrue(torch.is_tensor(candidates))
+        self.assertTrue(torch.is_tensor(scores))
+
+    def test_mismatched_vocabularies(self):
+        """Test handling of mismatched vocabularies between models."""
+        # Build an input from a token that exists in the target (main)
+        # vocabulary but not in the assistant vocabulary, skipping
+        # special and reserved tokens.
+        missing_token = next(
+            token for token in self.main_tokenizer.get_vocab()
+            if token not in self.assistant_tokenizer.get_vocab()
+            and token not in self.main_tokenizer.all_special_tokens
+            and "reserved_" not in token
+        )
+        input_ids = torch.tensor([[self.main_tokenizer.convert_tokens_to_ids(missing_token)]])
+        self.generator.input_ids = input_ids
+        candidates, scores = self.generator.get_candidates(input_ids)
+        self.assertIsNotNone(candidates)
+
+    def test_speculation_depth(self):
+        """Test that candidate length respects different speculation depths."""
+        input_ids = self.main_tokenizer.encode("Test text", return_tensors="pt")
+        self.generator.input_ids = input_ids
+
+        for depth in [1, 8, 17]:
+            self.generator.num_assistant_tokens = depth
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertLessEqual(
+                candidates.shape[1] - input_ids.shape[1], depth
+            )
+
+    def test_device_consistency(self):
+        """Test that candidates come back on the same device as the input."""
+        if torch.cuda.is_available():
+            input_ids = torch.tensor([[1, 2, 3]]).to(
+                self.generator.assistant_model.device)
+            self.generator.input_ids = input_ids
+            candidates, scores = self.generator.get_candidates(input_ids)
+            self.assertEqual(candidates.device, input_ids.device)
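
Outside the unittest harness, the contract these tests pin down can be exercised standalone. Below is a minimal sketch assembled from setUp/test_basic_generation above, using the same checkpoints and constructor arguments shown in the diff (again, the meta-llama tokenizer is gated and may require Hub authentication):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
    from transformers.generation.candidate_generator import UniversalSpeculativeDecodingGenerator

    device = "cuda" if torch.cuda.is_available() else "cpu"
    assistant = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gpt2").to(device)
    target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    draft_tok = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")

    input_ids = target_tok.encode("The quick brown fox", return_tensors="pt").to(device)
    generator = UniversalSpeculativeDecodingGenerator(
        input_ids=input_ids,
        assistant_model=assistant,
        target_tokenizer=target_tok,
        assistant_tokenizer=draft_tok,
        generation_config=GenerationConfig(),
        model_kwargs={"attention_mask": torch.ones_like(input_ids)},
        target_vocab_size=target_tok.vocab_size,
    )

    # get_candidates returns candidate ids expressed in the *target* vocabulary
    # plus candidate scores; the tests above check both are torch tensors and
    # that no more than num_assistant_tokens new tokens are proposed.
    candidates, scores = generator.get_candidates(input_ids)
    assert candidates.shape[1] - input_ids.shape[1] <= generator.num_assistant_tokens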
