
Commit a85c542

Test and fix candidates detection
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 47d3dcc commit a85c542

File tree

2 files changed: +202 additions, -168 deletions


tests/multimodal/test_processing.py

Lines changed: 79 additions & 98 deletions
@@ -1,6 +1,9 @@
 import pytest
+from transformers import PreTrainedTokenizerBase

-from vllm.multimodal.processing import apply_placeholders, iter_token_runs
+from vllm.multimodal.processing import (find_token_match_by_text,
+                                        iter_token_runs)
+from vllm.multimodal.utils import cached_get_tokenizer


 # yapf: disable
@@ -34,108 +37,86 @@ def test_iter_token_runs(token_ids, expected):
     assert result == expected


-# yapf: disable
+@pytest.mark.parametrize("tokenizer_id", [
+    "llava-hf/llava-1.5-7b-hf",
+    "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "microsoft/Phi-3.5-vision-instruct",
+    "Qwen/Qwen2-VL-2B-Instruct",
+])
+@pytest.mark.parametrize("add_special_tokens", [True, False])
 @pytest.mark.parametrize(
-    (
-        "token_ids", "match_ids", "replacement_id", "replacement_count",
-        "expected_new_token_ids", "expected_range",
-    ),
+    "text",
+    [
+        "What is in this image?",
+        # LLaVA
+        "<image>What is in this image?",
+        "What is<image>in this image?",
+        "What is in this image?<image>",
+        # LLama-3.2
+        "<|image|>What is in this image?",
+        "What is<|image|>in this image?",
+        "What is in this image?<|image|>",
+        # Phi-3-vision
+        "<image_1>What is in this image?",
+        "What is<image_1>in this image?",
+        "What is in this image?<image_1>",
+        # Qwen2-VL
+        "<|vision_start|><|image_pad|><|vision_end|>What is in this image?",
+        "What is<|vision_start|><|image_pad|><|vision_end|>in this image?",
+        "What is in this image?<|vision_start|><|image_pad|><|vision_end|>",
+    ])
+@pytest.mark.parametrize(
+    "match_str",
     [
-        # Empty
-        (
-            [], [-1], +1, 0,
-            [], None,
-        ),
         # No match
-        (
-            [32000, 32000, 32000], [-1], +1, 0,
-            [32000, 32000, 32000], None,
-        ),
-        # Match first
-        (
-            [-1, 32000, 32000], [-1], +1, 0,
-            [32000, 32000], { "offset": 0, "length": 0 },
-        ),
-        (
-            [-1, 32000, 32000], [-1], +1, 1,
-            [+1, 32000, 32000], { "offset": 0, "length": 1 },
-        ),
-        (
-            [-1, 32000, 32000], [-1], +1, 2,
-            [+1, +1, 32000, 32000], { "offset": 0, "length": 2 },
-        ),
-        # Match middle
-        (
-            [32000, -1, 32000], [-1], +1, 0,
-            [32000, 32000], { "offset": 1, "length": 0 },
-        ),
-        (
-            [32000, -1, 32000], [-1], +1, 1,
-            [32000, +1, 32000], { "offset": 1, "length": 1 },
-        ),
-        (
-            [32000, -1, 32000], [-1], +1, 2,
-            [32000, +1, +1, 32000], { "offset": 1, "length": 2},
-        ),
-        # Match last
-        (
-            [32000, 32000, -1], [-1], +1, 0,
-            [32000, 32000], { "offset": 2, "length": 0 },
-        ),
-        (
-            [32000, 32000, -1], [-1], +1, 1,
-            [32000, 32000, +1], { "offset": 2, "length": 1 },
-        ),
-        (
-            [32000, 32000, -1], [-1], +1, 2,
-            [32000, 32000, +1, +1], { "offset": 2, "length": 2},
-        ),
-        # Match all
-        (
-            [32000, 32000, 32000], [32000], +1, 0,
-            [32000, 32000], { "offset": 0, "length": 0 },
-        ),
-        (
-            [32000, 32000, 32000], [32000], +1, 1,
-            [+1, 32000, 32000], { "offset": 0, "length": 1 },
-        ),
-        (
-            [32000, 32000, 32000], [32000], +1, 2,
-            [+1, +1, 32000, 32000], { "offset": 0, "length": 2 },
-        ),
-    ],
-)
-# yapf: enable
-def test_apply_placeholders(
-    token_ids,
-    match_ids,
-    replacement_id,
-    replacement_count,
-    expected_new_token_ids,
-    expected_range,
+        "No",
+        # Has match
+        "i",
+        "What",
+        "What is",
+        "image",
+        "image?",
+        "<image>",
+        "<|image|>",
+        "<image_1>",
+        "<|vision_start|><|image_pad|><|vision_end|>",
+        "<s>",
+        "</s>",
+    ])
+def test_token_match_by_text(
+    tokenizer_id,
+    add_special_tokens,
+    text,
+    match_str,
 ):
-    orig_token_ids = token_ids[:]
-
-    placeholder_range = apply_placeholders(
-        token_ids,
-        match_ids,
-        replacement_id,
-        replacement_count,
-    )
+    tokenizer = cached_get_tokenizer(tokenizer_id)
+    assert isinstance(tokenizer, PreTrainedTokenizerBase)

-    # Invariants
-    if placeholder_range is None:
-        assert orig_token_ids == token_ids
-    else:
-        offset = placeholder_range["offset"]
-        match_len = len(match_ids)
-        repl_len = placeholder_range["length"]
+    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
+    match = find_token_match_by_text(tokenizer, token_ids, text, match_str)

-        assert orig_token_ids[offset:offset + match_len] == match_ids
+    # These are only shown in the output if the test fails
+    print("token_ids:", token_ids)
+    print("match:", match)

-        repl_ids = [replacement_id] * replacement_count
-        assert token_ids[offset:offset + repl_len] == repl_ids
+    # Invariants
+    if (match_str in text
+            or match_str in tokenizer.decode(token_ids,
+                                             skip_special_tokens=False)):
+        assert match is not None
+        match_start_idx, match_end_idx = match

-        # Manually constructed results
-        assert token_ids == expected_new_token_ids
-        assert placeholder_range == expected_range
+        assert match_str in tokenizer.decode(
+            token_ids[match_start_idx:match_end_idx],
+            skip_special_tokens=False,
+        )
+        assert match_str not in tokenizer.decode(
+            token_ids[match_start_idx + 1:match_end_idx],
+            skip_special_tokens=False,
+        )
+        assert match_str not in tokenizer.decode(
+            token_ids[match_start_idx:match_end_idx - 1],
+            skip_special_tokens=False,
+        )
+    else:
+        assert match is None
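
Note: this view shows only the test file; the second changed file (the implementation of find_token_match_by_text in vllm.multimodal.processing) is not displayed here. For orientation only, below is a minimal sketch of the kind of function these tests could exercise. It is inferred purely from the call signature and invariants in the test above, not taken from the actual commit; the brute-force shortest-span scan and the treatment of the unused text argument are assumptions.

from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizerBase


def find_token_match_by_text(
    tokenizer: PreTrainedTokenizerBase,
    token_ids: List[int],
    text: str,
    match_str: str,
) -> Optional[Tuple[int, int]]:
    """Find the smallest token span whose decoded text contains match_str.

    Returns (start_idx, end_idx) into token_ids, or None if no span matches.
    Illustrative sketch only; the real vLLM implementation may differ.
    """
    # `text` is unused in this naive version; a real implementation might use
    # it (e.g. via character-to-token offset mapping) to narrow the search.
    del text

    num_tokens = len(token_ids)
    # Scan spans from shortest to longest so the first hit is minimal:
    # dropping a token from either end of the returned span removes the
    # match, which is exactly what the test invariants above assert.
    for length in range(1, num_tokens + 1):
        for start in range(num_tokens - length + 1):
            end = start + length
            window = tokenizer.decode(token_ids[start:end],
                                      skip_special_tokens=False)
            if match_str in window:
                return start, end

    return None

The quadratic decode loop is only meant to make the contract checked by the tests concrete; it is not intended as an efficient implementation.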
