
Commit 4861d54

Provide necessary data for replacement
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 4b23817 commit 4861d54

File tree: 3 files changed, +102 -41 lines


tests/multimodal/test_processing.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def test_token_match_by_text(
             or match_str in tokenizer.decode(token_ids,
                                              skip_special_tokens=False)):
         assert match is not None
-        match_start_idx, match_end_idx = match
+        match_start_idx, match_end_idx, *_ = match
 
         assert match_str in tokenizer.decode(
             token_ids[match_start_idx:match_end_idx],
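
Note: the test now unpacks only the first two fields because the match result gains extra fields in this commit (see processing.py below). A quick standalone illustration of why the starred unpacking is robust:

    # NamedTuples unpack like plain tuples, so taking the first two fields and
    # discarding the rest works for both the old 2-tuple and the new result.
    match = (2, 5, [], [], "", "")  # shaped like the new _TokenMatchFromText
    match_start_idx, match_end_idx, *_ = match
    assert (match_start_idx, match_end_idx) == (2, 5)
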

vllm/multimodal/inputs.py

Lines changed: 1 addition & 8 deletions
@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
     """The type of inputs."""
 
     prompt: str
-    """
-    The original, unprocessed prompt text.
-
-    Note:
-        Since prompt text is not required by vLLM internals, we leave this
-        unprocessed to save CPU computation. You can still call
-        :code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
-    """
+    """The processed prompt text."""
 
     prompt_token_ids: List[int]
     """The processed token IDs which includes placeholder tokens."""

vllm/multimodal/processing.py

Lines changed: 100 additions & 32 deletions
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
 from functools import lru_cache
-from heapq import nsmallest
 from itertools import groupby
 from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
                     Optional, TypeVar, Union, final)
@@ -147,6 +146,9 @@ def _encode(
     return tokenizer.encode(text, add_special_tokens=add_special_tokens)
 
 
+_cached_encode = lru_cache(_encode)
+
+
 @lru_cache
 def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int:
     return max(len(token_text) for token_text in tokenizer.get_vocab())
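
Note: `_cached_encode` memoizes `_encode` with `functools.lru_cache`, which helps because the same `match_text` (for example a placeholder string) is encoded repeatedly, while the uncached `_encode` remains in use for one-off strings such as `left_text`/`right_text` further down. The wrapping pattern in isolation (standalone sketch; `fake_encode` is a stand-in, not vLLM's helper):

    from functools import lru_cache

    def fake_encode(text: str, *, add_special_tokens: bool = False) -> tuple:
        print("encoding:", text)              # only runs on a cache miss
        return tuple(ord(c) for c in text)

    cached_encode = lru_cache(fake_encode)
    cached_encode("<image>")   # prints "encoding: <image>"
    cached_encode("<image>")   # served from the cache, no print
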
@@ -157,7 +159,10 @@ class _TokenMatch(NamedTuple):
     end_idx: int
 
 
-def find_token_match(token_ids: List[int], match_ids: List[int]):
+def find_token_match(
+    token_ids: List[int],
+    match_ids: List[int],
+) -> Optional[_TokenMatch]:
     """
     Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
     """
@@ -171,25 +176,49 @@ def find_token_match(token_ids: List[int], match_ids: List[int]):
     return None
 
 
-class _Candidate(NamedTuple):
+class _TokenMatchFromTextCandidate(NamedTuple):
     start_idx: int
     end_idx: int
-    distance: int
+
+    match_text_prefix: str
+    match_text_suffix: str
+
+    @property
+    def distance(self) -> int:
+        return len(self.match_text_prefix) + len(self.match_text_suffix)
+
+
+class _TokenMatchFromText(NamedTuple):
+    start_idx: int
+    end_idx: int
+
+    match_prefix: List[int]
+    match_suffix: List[int]
+
+    match_text_prefix: str
+    match_text_suffix: str
 
 
 def find_token_match_by_text(
     tokenizer: AnyTokenizer,
     token_ids: List[int],
     token_text: str,
     match_text: str,
-):
+) -> Optional[_TokenMatchFromText]:
     """
     Find the first occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`.
     """
-    match_ids = _encode(tokenizer, match_text, add_special_tokens=False)
+    match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False)
     if (match := find_token_match(token_ids, match_ids)):
-        return match
+        return _TokenMatchFromText(
+            match.start_idx,
+            match.end_idx,
+            match_prefix=[],
+            match_suffix=[],
+            match_text_prefix="",
+            match_text_suffix="",
+        )
 
     # When `match_text` is not mapped to a special token ID,
     # it may be tokenized differently based on the surrounding tokens
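
Note: `_TokenMatchFromTextCandidate` replaces the old `_Candidate`; instead of storing a precomputed distance it keeps the decoded text on either side of `match_text` and derives the distance from it. Since the NamedTuple is shown in full above, a tiny self-contained check of the `distance` property:

    from typing import NamedTuple

    class _TokenMatchFromTextCandidate(NamedTuple):
        start_idx: int
        end_idx: int

        match_text_prefix: str
        match_text_suffix: str

        @property
        def distance(self) -> int:
            return len(self.match_text_prefix) + len(self.match_text_suffix)

    # A token window whose decoded text has ": " before and "\n" after match_text:
    cand = _TokenMatchFromTextCandidate(3, 7, ": ", "\n")
    assert cand.distance == 3  # len(": ") + len("\n")
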
@@ -202,37 +231,41 @@ def find_token_match_by_text(
     text_end_idx = text_start_idx + len(match_text)
 
     # In case the left/right side of `match_text` is fused with the
-    # string immediately before/after it during tokenization
+    # string immediately before/after it as a single token
     text_buffer = _max_vocab_token_len(tokenizer) - 1
     left_text = token_text[:max(0, text_start_idx - text_buffer)]
     right_text = token_text[:text_end_idx + text_buffer]
 
     left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
     right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
-    avg_idx = (left_idx + right_idx) // 2
     window_size = len(match_ids)
 
-    valid_candidates = list[_Candidate]()
-    for start_idx in sorted(range(left_idx, right_idx - window_size + 1),
-                            key=lambda x: abs(x - avg_idx)):
+    best_distance = len(token_text)
+    best_candidate = None
+
+    for start_idx in range(left_idx, right_idx - window_size + 1):
         end_idx = start_idx + window_size
         candidate_text = tokenizer.decode(
             token_ids[start_idx:end_idx],
+            # In case match_text is a special token
            skip_special_tokens=False,
         )
 
         if match_text in candidate_text:
-            candidate = _Candidate(
-                start_idx=start_idx,
-                end_idx=end_idx,
-                distance=len(candidate_text) - len(match_text),
+            candidate = _TokenMatchFromTextCandidate(
+                start_idx,
+                end_idx,
+                *candidate_text.split(match_text, 1),
             )
-            valid_candidates.append(candidate)
 
-            if candidate.distance == 0:
+            if candidate.distance < best_distance:
+                best_candidate = candidate
+                best_distance = candidate.distance
+
+            if best_distance == 0:
                 break
 
-    assert len(valid_candidates) > 0, dict(
+    assert best_candidate is not None, dict(
         # To facilitate debugging
         token_ids=token_ids,
         match_ids=match_ids,
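
Note: the rewritten loop tracks the single best candidate in one pass (breaking early on an exact fit) instead of collecting every valid candidate and picking the smallest afterwards. The starred `candidate_text.split(match_text, 1)` supplies the prefix/suffix fields; for example (strings are illustrative):

    candidate_text = "USER: <image>\nWhat"
    match_text = "<image>"
    prefix, suffix = candidate_text.split(match_text, 1)
    assert (prefix, suffix) == ("USER: ", "\nWhat")
    # the candidate's distance would then be len(prefix) + len(suffix) == 11
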
@@ -242,8 +275,25 @@ def find_token_match_by_text(
         right_idx=right_idx,
     )
 
-    best_candidate, = nsmallest(1, valid_candidates, key=lambda x: x.distance)
-    return best_candidate.start_idx, best_candidate.end_idx
+    match_token_prefix = _cached_encode(
+        tokenizer,
+        best_candidate.match_text_prefix,
+        add_special_tokens=False,
+    )
+    match_token_suffix = _cached_encode(
+        tokenizer,
+        best_candidate.match_text_suffix,
+        add_special_tokens=False,
+    )
+
+    return _TokenMatchFromText(
+        start_idx=best_candidate.start_idx,
+        end_idx=best_candidate.end_idx,
+        match_prefix=match_token_prefix,
+        match_suffix=match_token_suffix,
+        match_text_prefix=best_candidate.match_text_prefix,
+        match_text_suffix=best_candidate.match_text_suffix,
+    )
 
 
 def apply_placeholders(
@@ -253,7 +303,7 @@
     match_text: str,
     replacement_id: int,
     replacement_count: int,
-) -> Optional[PlaceholderRange]:
+) -> tuple[List[int], str, Optional[PlaceholderRange]]:
     """
     Find the first occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`, and replace it with
@@ -269,13 +319,25 @@ def apply_placeholders(
     )
 
     if match is None:
-        return None
+        return token_ids, token_text, None
+
+    start_idx, end_idx, prefix_ids, suffix_ids, prefix_str, suffix_str = match
 
-    # TODO(youkaichao): Don't update new_token_ids
-    start_idx, end_idx = match
-    token_ids[start_idx:end_idx] = [replacement_id] * replacement_count
+    replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
+                       suffix_ids)
+    replacement_text = tokenizer.decode(
+        replacement_ids,
+        # In case match_text is a special token
+        skip_special_tokens=False,
+    )
+
+    token_ids[start_idx:end_idx] = replacement_ids
+    token_text = token_text.replace(prefix_str + match_text + suffix_str,
+                                    replacement_text, 1)
 
-    return PlaceholderRange(offset=start_idx, length=replacement_count)
+    return (token_ids, token_text,
+            PlaceholderRange(offset=start_idx + len(prefix_ids),
+                             length=replacement_count))
 
 
 class MultiModalProcessor:
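
Note: `apply_placeholders` now returns the updated token IDs and prompt text alongside the placeholder range, and the range's offset is shifted past any prefix tokens that were fused into the match. A toy walkthrough of that arithmetic (numbers are made up; no real tokenizer involved):

    # Suppose the matched window is token_ids[4:7], with one fused prefix
    # token and no suffix token.
    prefix_ids, suffix_ids = [101], []
    replacement_id, replacement_count = 32000, 3

    replacement_ids = (prefix_ids + [replacement_id] * replacement_count +
                       suffix_ids)
    assert replacement_ids == [101, 32000, 32000, 32000]

    start_idx = 4
    offset = start_idx + len(prefix_ids)  # 5: the range starts after the prefix
    length = replacement_count            # 3
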
@@ -318,6 +380,7 @@ def apply(
         new_token_ids, = processed_inputs.pop("input_ids").tolist()
         mm_kwargs = MultiModalKwargs(processed_inputs)
 
+        new_prompt = prompt
         mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
 
         for modality, orig_inputs in to_multi_format(mm_data).items():
@@ -337,8 +400,9 @@
                 if new_token_id in repl_token_ids:
                     modality_placeholders.append(run_info)
 
-            # Otherwise, we insert them ourselves
-            if not modality_placeholders:
+            if modality_placeholders:
+                new_prompt = tokenizer.decode(new_token_ids)
+            else:  # Otherwise, we insert them ourselves
                 for item_idx, orig_item in enumerate(orig_inputs):
                     for match_str, replacement in placeholder_repls.items():
                         replacement_count = replacement["count"]
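
Note: `new_prompt` now tracks the processed prompt text along both paths: when the HF processor has already emitted placeholder tokens it is recovered by decoding `new_token_ids`, otherwise it is updated by `apply_placeholders` in the loop shown in the next hunk. A toy illustration of the decode branch (toy vocabulary, not a real tokenizer):

    toy_vocab = {1: "USER:", 32000: "<image>", 7: "Describe", 8: "this."}

    def toy_decode(token_ids):
        return " ".join(toy_vocab[t] for t in token_ids)

    new_token_ids = [1, 32000, 7, 8]
    new_prompt = toy_decode(new_token_ids)
    assert new_prompt == "USER: <image> Describe this."
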
@@ -349,10 +413,14 @@
                                 item_idx,
                             )
 
-                        placeholders = apply_placeholders(
+                        (
+                            new_token_ids,
+                            new_prompt,
+                            placeholders,
+                        ) = apply_placeholders(
                             tokenizer,
                             new_token_ids,
-                            prompt,
+                            new_prompt,
                             match_str,
                             replacement["token_id"],
                             replacement_count,
@@ -365,7 +433,7 @@ def apply(
 
         return MultiModalInputsV2(
             type="multimodal",
-            prompt=prompt,
+            prompt=new_prompt,
             prompt_token_ids=new_token_ids,
             mm_kwargs=mm_kwargs,
             mm_placeholders=mm_placeholders,
