@@ -2,7 +2,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from vllm.multimodal.processing import (find_token_match_by_text,
-                                        iter_token_runs)
+                                        iter_token_runs, replace_by_text)
 from vllm.multimodal.utils import cached_get_tokenizer
 
 
@@ -43,7 +43,6 @@ def test_iter_token_runs(token_ids, expected):
     "microsoft/Phi-3.5-vision-instruct",
     "Qwen/Qwen2-VL-2B-Instruct",
 ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
 @pytest.mark.parametrize(
     "text",
     [
@@ -83,11 +82,12 @@ def test_iter_token_runs(token_ids, expected):
     "<s>",
     "</s>",
 ])
+@pytest.mark.parametrize("add_special_tokens", [True, False])
 def test_token_match_by_text(
     tokenizer_id,
-    add_special_tokens,
     text,
     match_str,
+    add_special_tokens,
 ):
     tokenizer = cached_get_tokenizer(tokenizer_id)
     assert isinstance(tokenizer, PreTrainedTokenizerBase)
@@ -120,3 +120,65 @@ def test_token_match_by_text(
         )
     else:
         assert match is None
+
+
+@pytest.mark.parametrize("tokenizer_id", ["llava-hf/llava-1.5-7b-hf"])
+@pytest.mark.parametrize(("input_text", "replacement_count", "expected_text"),
+                         [
+                             ("foo", 0, ""),
+                             ("bar", 0, "bar"),
+                             ("food", 0, "d"),
+                             ("foo", 1, "bar"),
+                             ("bar", 1, "bar"),
+                             ("food", 1, "bard"),
+                             ("foo", 2, "barbar"),
+                             ("bar", 2, "bar"),
+                             ("food", 2, "barbard"),
+                         ])
+@pytest.mark.parametrize("add_special_tokens", [True, False])
+def test_replace_by_text(
+    tokenizer_id,
+    input_text,
+    replacement_count,
+    expected_text,
+    add_special_tokens,
+):
+    tokenizer = cached_get_tokenizer(tokenizer_id)
+    assert isinstance(tokenizer, PreTrainedTokenizerBase)
+
+    vocab = tokenizer.get_vocab()
+    missing_tokens = {"▁foo", "▁bar", "▁food"} - vocab.keys()
+    assert not missing_tokens, missing_tokens
+    assert "▁bard" not in vocab
+
+    input_ids = tokenizer.encode(input_text,
+                                 add_special_tokens=add_special_tokens)
+    bar_id = vocab["bar"]
+
+    output_ids, output_text, replacement = replace_by_text(
+        tokenizer,
+        input_ids[:],  # Copy
+        input_text,
+        "foo",
+        bar_id,
+        replacement_count,
+    )
+
+    # These are only shown in the output if the test fails
+    print("input_ids:", input_ids)
+    print("output_ids:", output_ids)
+    print("output_text:", output_text)
+    print("replacement:", replacement)
+
+    # Invariants
+    if replacement is None:
+        assert output_ids == input_ids
+    else:
+        offset = replacement["offset"]
+        repl_len = replacement["length"]
+
+        assert output_ids[offset:offset + repl_len] == [bar_id] * repl_len
+        assert repl_len == replacement_count
+
+    # Manually constructed results
+    assert output_text == expected_text
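Aside for reviewers: the parametrized cases in the new test encode a simple contract for `replace_by_text` — the first occurrence of the match string ("foo") is replaced by `replacement_count` copies of the replacement token, and inputs with no match ("bar") pass through unchanged. Below is a minimal text-level reference model of that assumed contract. `replace_by_text_model` is a hypothetical helper written only for illustration; it is not the vLLM function, which operates on token IDs and additionally returns the replacement span.

```python
# Hypothetical reference model of the contract implied by the test table.
# NOT the vLLM implementation: `replace_by_text` works on token IDs, while
# this sketch works on plain strings to show the expected text-level result.
def replace_by_text_model(input_text: str, match: str,
                          replacement: str, count: int) -> str:
    """Replace the first occurrence of `match` with `count` copies of
    `replacement`; if `match` is absent, return the text unchanged."""
    idx = input_text.find(match)
    if idx == -1:
        return input_text
    return (input_text[:idx] + replacement * count
            + input_text[idx + len(match):])


# Mirrors rows of the parametrize table above:
assert replace_by_text_model("foo", "foo", "bar", 0) == ""
assert replace_by_text_model("food", "foo", "bar", 1) == "bard"
assert replace_by_text_model("foo", "foo", "bar", 2) == "barbar"
assert replace_by_text_model("bar", "foo", "bar", 1) == "bar"  # no match
```

This also clarifies why the test passes `input_ids[:]` into `replace_by_text`: keeping an untouched copy of `input_ids` lets the `replacement is None` branch assert that a no-match call leaves the token sequence unchanged even if the function mutates its argument in place.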