
Commit dbc0538

DarkLight1337 authored and Isotr0py committed

[VLM] Simplify post-processing of replacement info (vllm-project#12269)

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent c9510ee commit dbc0538

File tree

10 files changed (+175 -208 lines)
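At a glance, the commit replaces per-model apply() overrides (which trimmed PlaceholderRange entries after the fact) with a declarative PromptReplacementDetails that separates the full replacement from the part that carries multimodal features. A minimal sketch, assembled from the BLIP-2 diff below; the function name and the num_image_tokens parameter are illustrative, while PromptReplacement, PromptReplacementDetails and the full/features fields are taken from this commit:

# Minimal sketch based on the BLIP-2 diff below; the function name and the
# num_image_tokens parameter are illustrative, while PromptReplacement,
# PromptReplacementDetails and the full/features fields come from this commit.
from vllm.multimodal.processing import (PromptReplacement,
                                        PromptReplacementDetails)


def image_prompt_replacements(num_image_tokens: int) -> list[PromptReplacement]:
    return [
        PromptReplacement(
            modality="image",
            target="</s>",
            replacement=PromptReplacementDetails(
                # Everything inserted into the prompt.
                full="<image>" * num_image_tokens + "</s>",
                # Only the tokens that map to image feature embeddings;
                # the trailing "</s>" is excluded here, so no manual
                # PlaceholderRange trimming in apply() is needed.
                features="<image>" * num_image_tokens,
            ),
        )
    ]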

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def _test_processing_correctness(
         task="auto",
         tokenizer=model_id,
         tokenizer_mode="auto",
-        trust_remote_code=True,
+        trust_remote_code=model_info.trust_remote_code,
         seed=0,
         dtype="float16",
         revision=None,

tests/models/registry.py

Lines changed: 2 additions & 1 deletion
@@ -261,7 +261,8 @@ def check_available_online(
                                          trust_remote_code=True),
     "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"),  # noqa: E501
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
-    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
+    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
+                                     trust_remote_code=True),
     # [Encoder-decoder]
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501

tests/multimodal/test_processing.py

Lines changed: 23 additions & 19 deletions
@@ -7,12 +7,16 @@
 
 from vllm.config import ModelConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement,
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
+                                        PromptReplacement,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
                                         iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
+# yapf: enable
 from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -433,19 +437,19 @@ def test_find_replace_tokens(
         [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
         {
             "pattern_1": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=6,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
             ],
             "pattern_4": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_4",
                     item_idx=0,
                     start_idx=3,
-                    replacement=[32000],
+                    tokens=[32000],
                 ),
             ],
         }
@@ -455,25 +459,25 @@ def test_find_replace_tokens(
         [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
         {
             "pattern_1": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=1,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=1,
                     start_idx=5,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
             ],
             "pattern_3": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_3",
                     item_idx=0,
                     start_idx=7,
-                    replacement=[1550, 918, 1550],
+                    tokens=[1550, 918, 1550],
                 ),
             ],
             # No match for pattern_4 as it has lower priority than pattern_1
@@ -483,33 +487,33 @@ def test_find_replace_tokens(
         [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
         {
             "pattern_1": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=1,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_1",
                     item_idx=1,
                     start_idx=3,
-                    replacement=[32000, 32000],
+                    tokens=[32000, 32000],
                 ),
             ],
             "pattern_4": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_4",
                     item_idx=0,
                     start_idx=5,
-                    replacement=[32000],
+                    tokens=[32000],
                 ),
             ],
             "pattern_3": [
-                PlaceholderInfo(
+                PlaceholderFeaturesInfo(
                     modality="pattern_3",
                     item_idx=0,
                     start_idx=6,
-                    replacement=[1550, 918, 1550],
+                    tokens=[1550, 918, 1550],
                 ),
             ],
         }
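The renamed result type above is what the placeholder-finding path (as exercised by test_find_replace_tokens) now reports per modality: the tokens field holds only the feature tokens, whereas the old replacement field held the entire replacement. A small sketch, reusing values from the first test case above; all field names come from the test diff:

# Sketch reusing values from the first test case above; all field names are
# taken from the test diff.
from vllm.multimodal.processing import PlaceholderFeaturesInfo

placeholder = PlaceholderFeaturesInfo(
    modality="pattern_1",   # modality the placeholder belongs to
    item_idx=0,             # index of the item within that modality
    start_idx=6,            # offset of the placeholder in the token sequence
    tokens=[32000, 32000],  # only the feature tokens (formerly the replacement field)
)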

vllm/model_executor/models/aria.py

Lines changed: 2 additions & 8 deletions
@@ -342,13 +342,7 @@ def get_vision_config(self):
         return self.get_hf_config().vision_config
 
     def get_hf_processor(self):
-        processor = self.ctx.get_hf_processor(AriaProcessor)
-
-        # Patch for https://github.com/huggingface/transformers/issues/35768
-        processor.tokenizer.image_token = "<|img|>"
-        processor.image_token = "<|img|>"
-
-        return processor
+        return self.ctx.get_hf_processor(AriaProcessor)
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
@@ -381,7 +375,7 @@ def get_dummy_processor_inputs(
         }
 
         hf_processor = self.info.get_hf_processor()
-        image_token: str = hf_processor.image_token  # type: ignore
+        image_token: str = hf_processor.tokenizer.image_token  # type: ignore
 
         return ProcessorInputs(
             prompt_text=image_token * num_images,

vllm/model_executor/models/blip2.py

Lines changed: 8 additions & 25 deletions
@@ -14,12 +14,12 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
@@ -481,30 +481,13 @@ def _get_prompt_replacements(
             PromptReplacement(
                 modality="image",
                 target="</s>",
-                replacement="<image>" * num_image_tokens + "</s>",
+                replacement=PromptReplacementDetails(
+                    full="<image>" * num_image_tokens + "</s>",
+                    features="<image>" * num_image_tokens,
+                ),
             )
         ]
 
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <image> tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-
 
 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
                                         info=Blip2ProcessingInfo,

vllm/model_executor/models/chameleon.py

Lines changed: 13 additions & 29 deletions
@@ -28,12 +28,12 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
@@ -141,39 +141,23 @@ def _get_prompt_replacements(
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        image_tokens = processor.image_token * self.info.get_num_image_tokens()
 
         return [
             PromptReplacement(
                 modality="image",
                 target="<image>",
-                replacement="".join([
-                    processor.image_start_token,
-                    processor.image_token * self.info.get_num_image_tokens(),
-                    processor.image_end_token,
-                ]),
+                replacement=PromptReplacementDetails(
+                    full="".join([
+                        processor.image_start_token,
+                        image_tokens,
+                        processor.image_end_token,
+                    ]),
+                    features=image_tokens,
+                ),
             )
         ]
 
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only <image> tokens should be considered as placeholders,
-        # so we ignore the image_start_token and image_end_token
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"] + 1,
-                                 length=p["length"] - 2) for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-
 
 class ChameleonLayerNorm(nn.LayerNorm):
 
vllm/model_executor/models/fuyu.py

Lines changed: 11 additions & 27 deletions
@@ -16,7 +16,7 @@
 """ PyTorch Fuyu model."""
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+                    TypedDict)
 
 import torch
 import torch.nn as nn
@@ -30,13 +30,13 @@
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
@@ -215,9 +215,13 @@ def get_replacement_fuyu(item_idx: int):
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
+            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
+                            [_NEWLINE_TOKEN_ID]) * nrows
 
-            return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
-                    [bos_token_id])
+            return PromptReplacementDetails(
+                full=image_tokens + [bos_token_id],
+                features=image_tokens,
+            )
 
         return [
             PromptReplacement(
@@ -227,26 +231,6 @@ def get_replacement_fuyu(item_idx: int):
             )
         ]
 
-    def apply(
-        self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> MultiModalInputs:
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
-
-        # Only |SPEAKER| (image) tokens should be considered as placeholders,
-        # so we ignore the trailing bos_token_id
-        result["mm_placeholders"] = {
-            modality: [
-                PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
-                for p in ps
-            ]
-            for modality, ps in result["mm_placeholders"].items()
-        }
-
-        return result
-
 
 @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                         info=FuyuProcessingInfo,
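The Fuyu change above shows the same pattern with token IDs instead of strings: the per-item replacement callback now returns PromptReplacementDetails rather than a flat token list that a later apply() override had to shrink. A standalone sketch; the token-ID constants, ncols/nrows and bos_token_id are illustrative stand-ins for values the real processor derives from the image size and tokenizer:

# Standalone sketch mirroring the Fuyu diff above; the token-ID constants,
# ncols/nrows and bos_token_id are illustrative stand-ins for values the
# real processor derives from the image size and tokenizer.
from vllm.multimodal.processing import PromptReplacementDetails

_IMAGE_TOKEN_ID = 71011    # assumed ID for the |SPEAKER| (image) token
_NEWLINE_TOKEN_ID = 71019  # assumed ID for the raster newline token
bos_token_id = 1
ncols, nrows = 4, 2

image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows

replacement = PromptReplacementDetails(
    full=image_tokens + [bos_token_id],  # everything inserted into the prompt
    features=image_tokens,               # only the image-feature tokens
)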
