
Commit dc619cc

Iterate
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 43043ca · commit dc619cc

File tree

3 files changed: +440 -456 lines changed

tests/multimodal/test_processing.py

Lines changed: 71 additions & 155 deletions
@@ -1,27 +1,27 @@
 import pytest
-from transformers import PreTrainedTokenizerBase
 
-from vllm.multimodal.processing import (find_token_match_by_text,
-                                        iter_token_runs, replace_by_text)
-from vllm.multimodal.utils import cached_get_tokenizer
+from vllm.multimodal.processing import iter_token_matches, iter_token_runs
 
 
 # yapf: disable
 @pytest.mark.parametrize(
     ("token_ids", "expected"),
     [
         ([], []),
-        ([32000, 32000, 32000], [(32000, { "offset": 0, "length": 3 })]),
+        (
+            [32000, 32000, 32000],
+            [{ "token_id": 32000, "start_idx": 0, "length": 3 }],
+        ),
         (
             [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
             [
-                (9833, { "offset": 0, "length": 1 }),
-                (28747, { "offset": 1, "length": 1 }),
-                (32000, { "offset": 2, "length": 3 }),
-                (9833, { "offset": 5, "length": 1 }),
-                (28747, { "offset": 6, "length": 1 }),
-                (32000, { "offset": 7, "length": 2 }),
-                (918, { "offset": 9, "length": 1 }),
+                { "token_id": 9833, "start_idx": 0, "length": 1 },
+                { "token_id": 28747, "start_idx": 1, "length": 1 },
+                { "token_id": 32000, "start_idx": 2, "length": 3 },
+                { "token_id": 9833, "start_idx": 5, "length": 1 },
+                { "token_id": 28747, "start_idx": 6, "length": 1 },
+                { "token_id": 32000, "start_idx": 7, "length": 2 },
+                { "token_id": 918, "start_idx": 9, "length": 1 },
             ],
         ),
     ],
@@ -30,155 +30,71 @@
 def test_iter_token_runs(token_ids, expected):
     result = list(iter_token_runs(token_ids))
 
-    # Invariants
-    assert sum(run_info["length"] for _, run_info in result) == len(token_ids)
-
     # Manually constructed results
     assert result == expected
 
-
-@pytest.mark.parametrize("tokenizer_id", [
-    "llava-hf/llava-1.5-7b-hf",
-    "meta-llama/Llama-3.2-11B-Vision-Instruct",
-    "microsoft/Phi-3.5-vision-instruct",
-    "Qwen/Qwen2-VL-2B-Instruct",
-])
-@pytest.mark.parametrize(
-    "text",
-    [
-        "What is in this image?",
-        # LLaVA
-        "<image>What is in this image?",
-        "What is<image>in this image?",
-        "What is in this image?<image>",
-        # LLama-3.2
-        "<|image|>What is in this image?",
-        "What is<|image|>in this image?",
-        "What is in this image?<|image|>",
-        # Phi-3-vision
-        "<image_1>What is in this image?",
-        "What is<image_1>in this image?",
-        "What is in this image?<image_1>",
-        # Qwen2-VL
-        "<|vision_start|><|image_pad|><|vision_end|>What is in this image?",
-        "What is<|vision_start|><|image_pad|><|vision_end|>in this image?",
-        "What is in this image?<|vision_start|><|image_pad|><|vision_end|>",
-    ])
-@pytest.mark.parametrize(
-    "match_str",
-    [
-        # No match
-        "No",
-        # Has match
-        "i",
-        "What",
-        "What is",
-        "image",
-        "image?",
-        "<image>",
-        "<|image|>",
-        "<image_1>",
-        "<|vision_start|><|image_pad|><|vision_end|>",
-        "<s>",
-        "</s>",
-    ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
-def test_token_match_by_text(
-    tokenizer_id,
-    text,
-    match_str,
-    add_special_tokens,
-):
-    tokenizer = cached_get_tokenizer(tokenizer_id)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
-
-    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
-    match = find_token_match_by_text(tokenizer, token_ids, text, match_str)
-
-    # These are only shown in the output if the test fails
-    print("token_ids:", token_ids)
-    print("match:", match)
-
     # Invariants
-    if (match_str in text
-            or match_str in tokenizer.decode(token_ids,
-                                             skip_special_tokens=False)):
-        assert match is not None
-        match_start_idx, match_end_idx, *_ = match
-
-        assert match_str in tokenizer.decode(
-            token_ids[match_start_idx:match_end_idx],
-            skip_special_tokens=False,
-        )
-        assert match_str not in tokenizer.decode(
-            token_ids[match_start_idx + 1:match_end_idx],
-            skip_special_tokens=False,
-        )
-        assert match_str not in tokenizer.decode(
-            token_ids[match_start_idx:match_end_idx - 1],
-            skip_special_tokens=False,
-        )
-    else:
-        assert match is None
+    assert sum(run_info["length"] for run_info in result) == len(token_ids)
 
 
-@pytest.mark.parametrize("tokenizer_id", ["llava-hf/llava-1.5-7b-hf"])
-@pytest.mark.parametrize(("input_text", "replacement_count", "expected_text"),
-                         [
-                             ("foo", 0, ""),
-                             ("bar", 0, "bar"),
-                             ("food", 0, "d"),
-                             ("foo", 1, "bar"),
-                             ("bar", 1, "bar"),
-                             ("food", 1, "bard"),
-                             ("foo", 2, "barbar"),
-                             ("bar", 2, "bar"),
-                             ("food", 2, "barbard"),
-                         ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
-def test_replace_by_text(
-    tokenizer_id,
-    input_text,
-    replacement_count,
-    expected_text,
-    add_special_tokens,
-):
-    tokenizer = cached_get_tokenizer(tokenizer_id)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
-
-    vocab = tokenizer.get_vocab()
-    missing_tokens = {"▁foo", "▁bar", "▁food"} - vocab.keys()
-    assert not missing_tokens, missing_tokens
-    assert "▁bard" not in vocab
-
-    input_ids = tokenizer.encode(input_text,
-                                 add_special_tokens=add_special_tokens)
-    bar_id = vocab["bar"]
-
-    output_ids, output_text, replacement = replace_by_text(
-        tokenizer,
-        input_ids[:],  # Copy
-        input_text,
-        "foo",
-        bar_id,
-        replacement_count,
-    )
+# yapf: disable
+@pytest.mark.parametrize(
+    ("token_ids", "match_ids", "expected"),
+    [
+        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
+        ([], [32000], []),
+        (
+            [32000, 32000, 32000],
+            [32000],
+            [
+                { "start_idx": 0, "end_idx": 1 },
+                { "start_idx": 1, "end_idx": 2 },
+                { "start_idx": 2, "end_idx": 3 },
+            ],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000],
+            [
+                { "start_idx": 0, "end_idx": 2 },
+                { "start_idx": 1, "end_idx": 3 },
+            ],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000, 32000],
+            [{ "start_idx": 0, "end_idx": 3 }],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000],
+            [
+                { "start_idx": 1, "end_idx": 3 },
+                { "start_idx": 6, "end_idx": 8 },
+            ],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000, 32000, 32000],
+            [
+                { "start_idx": 1, "end_idx": 5 },
+            ],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 0, 32000],
+            [],
+        ),
+    ],
+)
+# yapf: enable
+def test_iter_token_matches(token_ids, match_ids, expected):
+    result = list(iter_token_matches(token_ids, match_ids))
 
-    # These are only shown in the output if the test fails
-    print("input_ids:", input_ids)
-    print("output_ids:", output_ids)
-    print("output_text:", output_text)
-    print("replacement:", replacement)
+    # Manually constructed results
+    assert [item._asdict() for item in result] == expected
 
     # Invariants
-    if replacement is None:
-        assert output_ids == input_ids
-    else:
-        offset = replacement["offset"]
-        repl_len = replacement["length"]
-
-        assert output_ids[offset:offset + repl_len] == [bar_id] * repl_len
-        assert repl_len == replacement_count
-
-        # Manually constructed results
-        assert output_text == expected_text
+    match_lens = [end - start for start, end in result]
+    print("match_lens:", match_lens)  # Only displayed on error
+    assert all(match_len == len(match_ids) for match_len in match_lens)
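
For context on the new expectations: test_iter_token_runs now expects plain dicts keyed by "token_id", "start_idx" and "length" instead of (token_id, run_info) tuples keyed by "offset". Below is a minimal sketch of run-length grouping that is consistent with those parametrized cases; the names iter_token_runs_sketch and _TokenRun are made up for illustration and this is not the actual helper in vllm.multimodal.processing.

from itertools import groupby
from typing import Iterable, List, TypedDict


class _TokenRun(TypedDict):
    token_id: int
    start_idx: int
    length: int


def iter_token_runs_sketch(token_ids: List[int]) -> Iterable[_TokenRun]:
    # Group consecutive identical token IDs into runs shaped like the
    # expected values in test_iter_token_runs.
    start_idx = 0
    for token_id, group in groupby(token_ids):
        length = sum(1 for _ in group)
        yield {"token_id": token_id, "start_idx": start_idx, "length": length}
        start_idx += length

For example, list(iter_token_runs_sketch([32000, 32000, 32000])) gives [{"token_id": 32000, "start_idx": 0, "length": 3}], matching the first non-empty case above, and the run lengths always sum to len(token_ids), which is the invariant the test keeps.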

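Likewise, the test_iter_token_matches cases imply a sliding-window search that yields every occurrence of match_ids inside token_ids, overlapping occurrences included, as a named tuple exposing start_idx and end_idx (the test calls item._asdict() and unpacks each item as (start, end)). A rough sketch under those assumptions, again with hypothetical names rather than the real implementation:

from typing import Iterable, List, NamedTuple


class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int


def iter_token_matches_sketch(
    token_ids: List[int],
    match_ids: List[int],
) -> Iterable[_TokenMatch]:
    # Yield every window of token_ids equal to match_ids, overlaps included.
    # With an empty match_ids this yields a zero-length match at each start
    # index; the tests only exercise the empty-token_ids case.
    match_len = len(match_ids)
    for start_idx in range(len(token_ids) - match_len + 1):
        if token_ids[start_idx:start_idx + match_len] == match_ids:
            yield _TokenMatch(start_idx, start_idx + match_len)

Every window this yields has length len(match_ids), which is exactly the invariant asserted at the end of the new test.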
0 commit comments