 import pytest
-from transformers import PreTrainedTokenizerBase
 
-from vllm.multimodal.processing import (find_token_match_by_text,
-                                        iter_token_runs, replace_by_text)
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.processing import iter_token_matches, iter_token_runs
 
 
 # yapf: disable
 @pytest.mark.parametrize(
     ("token_ids", "expected"),
     [
         ([], []),
-        ([32000, 32000, 32000], [(32000, { "offset": 0, "length": 3 })]),
+        (
+            [32000, 32000, 32000],
+            [{ "token_id": 32000, "start_idx": 0, "length": 3 }],
+        ),
         (
             [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
             [
-                (9833, { "offset": 0, "length": 1 }),
-                (28747, { "offset": 1, "length": 1 }),
-                (32000, { "offset": 2, "length": 3 }),
-                (9833, { "offset": 5, "length": 1 }),
-                (28747, { "offset": 6, "length": 1 }),
-                (32000, { "offset": 7, "length": 2 }),
-                (918, { "offset": 9, "length": 1 }),
+                { "token_id": 9833, "start_idx": 0, "length": 1 },
+                { "token_id": 28747, "start_idx": 1, "length": 1 },
+                { "token_id": 32000, "start_idx": 2, "length": 3 },
+                { "token_id": 9833, "start_idx": 5, "length": 1 },
+                { "token_id": 28747, "start_idx": 6, "length": 1 },
+                { "token_id": 32000, "start_idx": 7, "length": 2 },
+                { "token_id": 918, "start_idx": 9, "length": 1 },
             ],
         ),
     ],
 )
 # yapf: enable
 def test_iter_token_runs(token_ids, expected):
     result = list(iter_token_runs(token_ids))
 
-    # Invariants
-    assert sum(run_info["length"] for _, run_info in result) == len(token_ids)
-
     # Manually constructed results
     assert result == expected
 
-
-@pytest.mark.parametrize("tokenizer_id", [
-    "llava-hf/llava-1.5-7b-hf",
-    "meta-llama/Llama-3.2-11B-Vision-Instruct",
-    "microsoft/Phi-3.5-vision-instruct",
-    "Qwen/Qwen2-VL-2B-Instruct",
-])
-@pytest.mark.parametrize(
-    "text",
-    [
-        "What is in this image?",
-        # LLaVA
-        "<image>What is in this image?",
-        "What is<image>in this image?",
-        "What is in this image?<image>",
-        # LLama-3.2
-        "<|image|>What is in this image?",
-        "What is<|image|>in this image?",
-        "What is in this image?<|image|>",
-        # Phi-3-vision
-        "<image_1>What is in this image?",
-        "What is<image_1>in this image?",
-        "What is in this image?<image_1>",
-        # Qwen2-VL
-        "<|vision_start|><|image_pad|><|vision_end|>What is in this image?",
-        "What is<|vision_start|><|image_pad|><|vision_end|>in this image?",
-        "What is in this image?<|vision_start|><|image_pad|><|vision_end|>",
-    ])
-@pytest.mark.parametrize(
-    "match_str",
-    [
-        # No match
-        "No",
-        # Has match
-        "i",
-        "What",
-        "What is",
-        "image",
-        "image?",
-        "<image>",
-        "<|image|>",
-        "<image_1>",
-        "<|vision_start|><|image_pad|><|vision_end|>",
-        "<s>",
-        "</s>",
-    ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
-def test_token_match_by_text(
-    tokenizer_id,
-    text,
-    match_str,
-    add_special_tokens,
-):
-    tokenizer = cached_get_tokenizer(tokenizer_id)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
-
-    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
-    match = find_token_match_by_text(tokenizer, token_ids, text, match_str)
-
-    # These are only shown in the output if the test fails
-    print("token_ids:", token_ids)
-    print("match:", match)
-
     # Invariants
-    if (match_str in text
-            or match_str in tokenizer.decode(token_ids,
-                                             skip_special_tokens=False)):
-        assert match is not None
-        match_start_idx, match_end_idx, *_ = match
-
-        assert match_str in tokenizer.decode(
-            token_ids[match_start_idx:match_end_idx],
-            skip_special_tokens=False,
-        )
-        assert match_str not in tokenizer.decode(
-            token_ids[match_start_idx + 1:match_end_idx],
-            skip_special_tokens=False,
-        )
-        assert match_str not in tokenizer.decode(
-            token_ids[match_start_idx:match_end_idx - 1],
-            skip_special_tokens=False,
-        )
-    else:
-        assert match is None
+    assert sum(run_info["length"] for run_info in result) == len(token_ids)
 
 
-@pytest.mark.parametrize("tokenizer_id", ["llava-hf/llava-1.5-7b-hf"])
-@pytest.mark.parametrize(("input_text", "replacement_count", "expected_text"),
-                         [
-                             ("foo", 0, ""),
-                             ("bar", 0, "bar"),
-                             ("food", 0, "d"),
-                             ("foo", 1, "bar"),
-                             ("bar", 1, "bar"),
-                             ("food", 1, "bard"),
-                             ("foo", 2, "barbar"),
-                             ("bar", 2, "bar"),
-                             ("food", 2, "barbard"),
-                         ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
-def test_replace_by_text(
-    tokenizer_id,
-    input_text,
-    replacement_count,
-    expected_text,
-    add_special_tokens,
-):
-    tokenizer = cached_get_tokenizer(tokenizer_id)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
-
-    vocab = tokenizer.get_vocab()
-    missing_tokens = {"▁foo", "▁bar", "▁food"} - vocab.keys()
-    assert not missing_tokens, missing_tokens
-    assert "▁bard" not in vocab
-
-    input_ids = tokenizer.encode(input_text,
-                                 add_special_tokens=add_special_tokens)
-    bar_id = vocab["bar"]
-
-    output_ids, output_text, replacement = replace_by_text(
-        tokenizer,
-        input_ids[:],  # Copy
-        input_text,
-        "foo",
-        bar_id,
-        replacement_count,
-    )
+# yapf: disable
+@pytest.mark.parametrize(
+    ("token_ids", "match_ids", "expected"),
+    [
+        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
+        ([], [32000], []),
+        (
+            [32000, 32000, 32000],
+            [32000],
+            [
+                { "start_idx": 0, "end_idx": 1 },
+                { "start_idx": 1, "end_idx": 2 },
+                { "start_idx": 2, "end_idx": 3 },
+            ],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000],
+            [
+                { "start_idx": 0, "end_idx": 2 },
+                { "start_idx": 1, "end_idx": 3 },
+            ],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000, 32000],
+            [{ "start_idx": 0, "end_idx": 3 }],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000],
+            [
+                { "start_idx": 1, "end_idx": 3 },
+                { "start_idx": 6, "end_idx": 8 },
+            ],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000, 32000, 32000],
+            [
+                { "start_idx": 1, "end_idx": 5 },
+            ],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 0, 32000],
+            [],
+        ),
+    ],
+)
+# yapf: enable
+def test_iter_token_matches(token_ids, match_ids, expected):
+    result = list(iter_token_matches(token_ids, match_ids))
 
-    # These are only shown in the output if the test fails
-    print("input_ids:", input_ids)
-    print("output_ids:", output_ids)
-    print("output_text:", output_text)
-    print("replacement:", replacement)
+    # Manually constructed results
+    assert [item._asdict() for item in result] == expected
 
     # Invariants
-    if replacement is None:
-        assert output_ids == input_ids
-    else:
-        offset = replacement["offset"]
-        repl_len = replacement["length"]
-
-        assert output_ids[offset:offset + repl_len] == [bar_id] * repl_len
-        assert repl_len == replacement_count
-
-    # Manually constructed results
-    assert output_text == expected_text
+    match_lens = [end - start for start, end in result]
+    print("match_lens:", match_lens)  # Only displayed on error
+    assert all(match_len == len(match_ids) for match_len in match_lens)
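
The diff only shows the tests, but the expected values pin down the contract of
iter_token_runs: it groups consecutive identical token IDs into runs and yields
each run's "token_id", "start_idx", and "length" (the old API yielded
(token_id, {"offset", "length"}) tuples instead). A minimal sketch consistent
with these cases; the actual implementation in vllm.multimodal.processing may
differ:

    def iter_token_runs(token_ids):
        """Group consecutive identical token IDs into runs."""
        start_idx = 0
        while start_idx < len(token_ids):
            token_id = token_ids[start_idx]
            length = 1
            # Extend the run while the next token repeats the current one
            while (start_idx + length < len(token_ids)
                   and token_ids[start_idx + length] == token_id):
                length += 1
            yield { "token_id": token_id, "start_idx": start_idx, "length": length }
            start_idx += length

Since the runs partition the input, their lengths must sum to len(token_ids),
which is exactly the invariant test_iter_token_runs asserts.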
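
The test_iter_token_matches cases likewise imply the semantics of
iter_token_matches: it yields every occurrence of match_ids within token_ids,
including overlapping ones (searching [32000, 32000, 32000] for [32000, 32000]
yields matches at both 0..2 and 1..3), and each match behaves like a NamedTuple,
since the test calls item._asdict() and unpacks it as (start, end). A sketch
under those assumptions (the class name TokenMatch is hypothetical; the tests
never name the type):

    from typing import NamedTuple


    class TokenMatch(NamedTuple):
        start_idx: int
        end_idx: int


    def iter_token_matches(token_ids, match_ids):
        """Yield every (possibly overlapping) occurrence of match_ids."""
        match_len = len(match_ids)
        # Advance one position at a time so overlapping matches are reported
        for start_idx in range(len(token_ids) - match_len + 1):
            if token_ids[start_idx:start_idx + match_len] == match_ids:
                yield TokenMatch(start_idx, start_idx + match_len)

Note the empty-pattern edge case: with token_ids == [] and match_ids == [],
range(0 - 0 + 1) still produces one candidate, the empty match (0, 0), which is
exactly what the first parametrized case expects.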