Commit 43043ca

Add test
Signed-off-by: DarkLight1337 <[email protected]>
1 parent a4b4108 commit 43043ca

File tree

1 file changed: +65 -3 lines changed

tests/multimodal/test_processing.py

Lines changed: 65 additions & 3 deletions
@@ -2,7 +2,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from vllm.multimodal.processing import (find_token_match_by_text,
-                                        iter_token_runs)
+                                        iter_token_runs, replace_by_text)
 from vllm.multimodal.utils import cached_get_tokenizer
 
 
@@ -43,7 +43,6 @@ def test_iter_token_runs(token_ids, expected):
     "microsoft/Phi-3.5-vision-instruct",
     "Qwen/Qwen2-VL-2B-Instruct",
 ])
-@pytest.mark.parametrize("add_special_tokens", [True, False])
 @pytest.mark.parametrize(
     "text",
     [
@@ -83,11 +82,12 @@ def test_iter_token_runs(token_ids, expected):
     "<s>",
     "</s>",
 ])
+@pytest.mark.parametrize("add_special_tokens", [True, False])
 def test_token_match_by_text(
     tokenizer_id,
-    add_special_tokens,
     text,
     match_str,
+    add_special_tokens,
 ):
     tokenizer = cached_get_tokenizer(tokenizer_id)
     assert isinstance(tokenizer, PreTrainedTokenizerBase)
@@ -120,3 +120,65 @@ def test_token_match_by_text(
         )
     else:
         assert match is None
+
+
+@pytest.mark.parametrize("tokenizer_id", ["llava-hf/llava-1.5-7b-hf"])
+@pytest.mark.parametrize(("input_text", "replacement_count", "expected_text"),
+                         [
+                             ("foo", 0, ""),
+                             ("bar", 0, "bar"),
+                             ("food", 0, "d"),
+                             ("foo", 1, "bar"),
+                             ("bar", 1, "bar"),
+                             ("food", 1, "bard"),
+                             ("foo", 2, "barbar"),
+                             ("bar", 2, "bar"),
+                             ("food", 2, "barbard"),
+                         ])
+@pytest.mark.parametrize("add_special_tokens", [True, False])
+def test_replace_by_text(
+    tokenizer_id,
+    input_text,
+    replacement_count,
+    expected_text,
+    add_special_tokens,
+):
+    tokenizer = cached_get_tokenizer(tokenizer_id)
+    assert isinstance(tokenizer, PreTrainedTokenizerBase)
+
+    vocab = tokenizer.get_vocab()
+    missing_tokens = {"▁foo", "▁bar", "▁food"} - vocab.keys()
+    assert not missing_tokens, missing_tokens
+    assert "▁bard" not in vocab
+
+    input_ids = tokenizer.encode(input_text,
+                                 add_special_tokens=add_special_tokens)
+    bar_id = vocab["bar"]
+
+    output_ids, output_text, replacement = replace_by_text(
+        tokenizer,
+        input_ids[:],  # Copy
+        input_text,
+        "foo",
+        bar_id,
+        replacement_count,
+    )
+
+    # These are only shown in the output if the test fails
+    print("input_ids:", input_ids)
+    print("output_ids:", output_ids)
+    print("output_text:", output_text)
+    print("replacement:", replacement)
+
+    # Invariants
+    if replacement is None:
+        assert output_ids == input_ids
+    else:
+        offset = replacement["offset"]
+        repl_len = replacement["length"]
+
+        assert output_ids[offset:offset + repl_len] == [bar_id] * repl_len
+        assert repl_len == replacement_count
+
+    # Manually constructed results
+    assert output_text == expected_text
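
Note: the new test drives replace_by_text purely through its call site, so the snippet below is only a minimal usage sketch inferred from that call. The argument order, the (output_ids, output_text, replacement) return tuple, and the "offset"/"length" keys on replacement are taken from the test above; they are assumptions about the function's interface as exercised here, not a documented API reference.

# Minimal sketch of calling replace_by_text, assuming only the call pattern
# used by test_replace_by_text above.
from transformers import PreTrainedTokenizerBase

from vllm.multimodal.processing import replace_by_text
from vllm.multimodal.utils import cached_get_tokenizer

tokenizer = cached_get_tokenizer("llava-hf/llava-1.5-7b-hf")
assert isinstance(tokenizer, PreTrainedTokenizerBase)

vocab = tokenizer.get_vocab()
input_text = "food"
input_ids = tokenizer.encode(input_text, add_special_tokens=False)

# Replace the substring "foo" with two copies of the "bar" token.
output_ids, output_text, replacement = replace_by_text(
    tokenizer,
    input_ids[:],  # pass a copy, as the test does
    input_text,
    "foo",         # text to match
    vocab["bar"],  # token id to substitute
    2,             # number of replacement tokens to insert
)

# Per the parametrized expectations above, "food" should become "barbard",
# and replacement["offset"] / replacement["length"] should describe where the
# two "bar" tokens were written into output_ids (replacement is None when
# "foo" does not occur in the input).
print(output_text, replacement)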
