Skip to content

Commit a29daa4

Browse files
committed
Cleanup
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 6ab14df commit a29daa4

File tree

2 files changed

+106
-79
lines changed

2 files changed

+106
-79
lines changed

tests/multimodal/test_processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def test_iter_token_runs(token_ids, expected):
4141
print("result:", result)
4242

4343
# Manually constructed results
44-
assert result == expected
44+
assert [item._asdict() for item in result] == expected
4545

4646
# Invariants
47-
assert sum(run_info["length"] for run_info in result) == len(token_ids)
47+
assert sum(run_info.length for run_info in result) == len(token_ids)
4848

4949

5050
# yapf: disable

vllm/multimodal/processing.py

Lines changed: 104 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import re
22
from abc import ABC, abstractmethod
3+
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
34
from dataclasses import dataclass
45
from functools import lru_cache
56
from itertools import groupby
6-
from typing import (Any, Callable, Generic, Iterable, Mapping, NamedTuple,
7-
Optional, Sequence, TypeVar, Union)
7+
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
88

99
import numpy as np
1010
from transformers import BatchFeature
@@ -18,72 +18,21 @@
1818
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
1919
VideoItem)
2020

21-
22-
def _encode(
23-
tokenizer: AnyTokenizer,
24-
text: str,
25-
*,
26-
add_special_tokens: bool = False,
27-
) -> list[int]:
28-
"""
29-
Backend-agnostic equivalent of HF's
30-
:code:`tokenizer.encode(text, add_special_tokens=...)`.
31-
"""
32-
if isinstance(tokenizer, MistralTokenizer):
33-
return tokenizer.tokenizer.encode(text,
34-
bos=add_special_tokens,
35-
eos=add_special_tokens)
36-
37-
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
38-
39-
40-
@lru_cache(maxsize=2048)
41-
def _cached_encode(
42-
tokenizer: AnyTokenizer,
43-
text: str,
44-
*,
45-
add_special_tokens: bool = False,
46-
) -> list[int]:
47-
return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
48-
49-
50-
def _decode(
51-
tokenizer: AnyTokenizer,
52-
token_ids: list[int],
53-
*,
54-
skip_special_tokens: bool = False,
55-
) -> str:
56-
"""
57-
Backend-agnostic equivalent of HF's
58-
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
59-
"""
60-
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
61-
62-
63-
@lru_cache(maxsize=2048)
64-
def _cached_decode(
65-
tokenizer: AnyTokenizer,
66-
token_ids: tuple[int, ...],
67-
*,
68-
skip_special_tokens: bool = False,
69-
) -> str:
70-
return _decode(tokenizer,
71-
list(token_ids),
72-
skip_special_tokens=skip_special_tokens)
73-
74-
7521
PromptSegment: TypeAlias = Union[str, list[int]]
7622

7723

7824
def bind_segment(
79-
prompt_segment: PromptSegment,
25+
segment: PromptSegment,
8026
tokenizer: AnyTokenizer,
8127
) -> "_BoundPromptSegment":
28+
"""
29+
Bind a text or token prompt to a tokenizer so that it can be
30+
lazily converted into the other format on demand.
31+
"""
8232
return _BoundPromptSegment(
8333
tokenizer=tokenizer,
84-
_text=prompt_segment if isinstance(prompt_segment, str) else None,
85-
_token_ids=prompt_segment
86-
if isinstance(prompt_segment, list) else None,
34+
_text=segment if isinstance(segment, str) else None,
35+
_token_ids=segment if isinstance(segment, list) else None,
8736
)
8837

8938

@@ -163,6 +112,78 @@ class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
163112
"""
164113

165114

115+
def _encode(
116+
tokenizer: AnyTokenizer,
117+
text: str,
118+
*,
119+
add_special_tokens: bool = False,
120+
) -> list[int]:
121+
"""
122+
Backend-agnostic equivalent of HF's
123+
:code:`tokenizer.encode(text, add_special_tokens=...)`.
124+
"""
125+
if isinstance(tokenizer, MistralTokenizer):
126+
return tokenizer.tokenizer.encode(text,
127+
bos=add_special_tokens,
128+
eos=add_special_tokens)
129+
130+
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
131+
132+
133+
@lru_cache(maxsize=2048)
134+
def _cached_encode(
135+
tokenizer: AnyTokenizer,
136+
text: str,
137+
*,
138+
add_special_tokens: bool = False,
139+
) -> list[int]:
140+
return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
141+
142+
143+
def _decode(
144+
tokenizer: AnyTokenizer,
145+
token_ids: list[int],
146+
*,
147+
skip_special_tokens: bool = False,
148+
) -> str:
149+
"""
150+
Backend-agnostic equivalent of HF's
151+
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
152+
"""
153+
return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
154+
155+
156+
@lru_cache(maxsize=2048)
157+
def _cached_decode(
158+
tokenizer: AnyTokenizer,
159+
token_ids: tuple[int, ...],
160+
*,
161+
skip_special_tokens: bool = False,
162+
) -> str:
163+
return _decode(tokenizer,
164+
list(token_ids),
165+
skip_special_tokens=skip_special_tokens)
166+
167+
168+
class _HasModalityAttr(Protocol):
169+
modality: str
170+
171+
172+
class _HasModalityProp(Protocol):
173+
174+
@property
175+
def modality(self) -> str:
176+
...
177+
178+
179+
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
180+
181+
182+
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
183+
"""Convenience function to apply :func:`full_groupby` based on modality."""
184+
return full_groupby(values, key=lambda x: x.modality)
185+
186+
166187
@dataclass
167188
class _BoundPromptSegment:
168189
tokenizer: AnyTokenizer
@@ -237,7 +258,7 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
237258
return multi_data
238259

239260

240-
class _TokenRun(TypedDict):
261+
class _TokenRun(NamedTuple):
241262
token_id: int
242263

243264
start_idx: int
@@ -257,19 +278,20 @@ def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
257278
start_idx += length
258279

259280

260-
class _BoundPlaceholderRange(TypedDict):
281+
class _BoundPlaceholderRange(NamedTuple):
261282
modality: str
262283
offset: int
263284
length: int
264285

265286

266287
def iter_placeholders(
267288
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
268-
new_token_ids: list[int],
289+
token_ids: list[int],
269290
*,
270291
min_placeholder_count: int,
271292
) -> Iterable[_BoundPlaceholderRange]:
272-
repls_by_modality = full_groupby(prompt_repls, key=lambda x: x.modality)
293+
"""Yield each set of placeholder tokens found in :code:`token_ids`."""
294+
repls_by_modality = full_groupby_modality(prompt_repls)
273295

274296
placeholder_ids_by_modality = {
275297
modality: {
@@ -280,15 +302,15 @@ def iter_placeholders(
280302
for modality, repls in repls_by_modality
281303
}
282304

283-
for run_info in iter_token_runs(new_token_ids):
284-
if run_info["length"] > min_placeholder_count:
305+
for run_info in iter_token_runs(token_ids):
306+
if run_info.length > min_placeholder_count:
285307
for (modality,
286308
placeholder_ids) in placeholder_ids_by_modality.items():
287-
if run_info["token_id"] in placeholder_ids:
309+
if run_info.token_id in placeholder_ids:
288310
yield _BoundPlaceholderRange(
289311
modality=modality,
290-
offset=run_info["start_idx"],
291-
length=run_info["length"],
312+
offset=run_info.start_idx,
313+
length=run_info.length,
292314
)
293315

294316

@@ -398,6 +420,7 @@ def find_token_matches(
398420
prompt: list[int],
399421
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
400422
) -> list[_PromptReplacementTokenMatch[_T]]:
423+
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
401424
return [
402425
_PromptReplacementTokenMatch(prompt_repl, match)
403426
for prompt_repl in prompt_repls
@@ -409,17 +432,22 @@ def find_text_matches(
409432
prompt: str,
410433
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
411434
) -> list[_PromptReplacementTextMatch[_T]]:
435+
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
412436
return [
413437
_PromptReplacementTextMatch(prompt_repl, match)
414438
for prompt_repl in prompt_repls
415439
for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
416440
]
417441

418442

419-
def unique_sort_matches(
443+
def resolve_matches(
420444
prompt: _S,
421445
matches: Sequence[_PromptReplacementMatch[_T, _S]],
422446
) -> list[_PromptReplacementMatch[_T, _S]]:
447+
"""
448+
Resolve :code:`matches` to ensure that there are no overlapping matches,
449+
and sort them such that earlier matches take priority over later ones.
450+
"""
423451
num_matches_by_idx = np.zeros(len(prompt), dtype=int)
424452
for match in matches:
425453
num_matches_by_idx[match.start_idx:match.end_idx] += 1
@@ -443,8 +471,7 @@ def _replace_matches(
443471
prev_end_idx = 0
444472
next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}
445473

446-
# Earlier matches take priority over later ones
447-
for match in unique_sort_matches(prompt, matches):
474+
for match in resolve_matches(prompt, matches):
448475
modality = match.modality
449476
mm_items = mm_items_by_modality[modality]
450477

@@ -471,6 +498,7 @@ def replace_token_matches(
471498
mm_items_by_modality: Mapping[str, list[_T]],
472499
hf_inputs: BatchFeature,
473500
) -> list[int]:
501+
"""Apply :code:`prompt_repls` to :code:`prompt`."""
474502
if not matches:
475503
return prompt
476504

@@ -490,6 +518,7 @@ def replace_text_matches(
490518
mm_items_by_modality: Mapping[str, list[_T]],
491519
hf_inputs: BatchFeature,
492520
) -> str:
521+
"""Apply :code:`prompt_repls` to :code:`prompt`."""
493522
if not matches:
494523
return prompt
495524

@@ -581,8 +610,7 @@ def _apply_prompt_replacements(
581610

582611
if all(
583612
len(matches) >= len(mm_data[modality])
584-
for modality, matches in full_groupby(token_matches,
585-
key=lambda x: x.modality)
613+
for modality, matches in full_groupby_modality(token_matches)
586614
): # yapf: disable
587615
token_ids = replace_token_matches(
588616
token_ids,
@@ -647,11 +675,10 @@ def apply(
647675

648676
mm_placeholders = {
649677
modality: [
650-
PlaceholderRange(offset=item["offset"], length=item["length"])
678+
PlaceholderRange(offset=item.offset, length=item.length)
651679
for item in items
652680
]
653-
for modality, items in full_groupby(all_placeholders,
654-
key=lambda x: x["modality"])
681+
for modality, items in full_groupby_modality(all_placeholders)
655682
}
656683

657684
return MultiModalInputsV2(

0 commit comments

Comments
 (0)