Commit 365801f

[VLM] Add max-count checking in data parser for single image models (#11661)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 4db72e5 commit 365801f

6 files changed, 48 insertions(+), 11 deletions(-)

docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -566,7 +566,7 @@ See [this page](#generative-models) for more information on how to use generativ
   - [V1](gh-issue:8779)
 * - `AriaForConditionalGeneration`
   - Aria
-  - T + I
+  - T + I<sup>+</sup>
   - `rhymes-ai/Aria`
   -
   - ✅︎

tests/multimodal/test_processing.py

Lines changed: 2 additions & 1 deletion
@@ -622,10 +622,11 @@ def _test_processing_cache_correctness(


 # yapf: disable
+# True if the model supports multiple data items of the modality per request
 @pytest.mark.parametrize(("model_id", "modalities"), [
     ("rhymes-ai/Aria", {"image": True}),
     ("Salesforce/blip2-opt-2.7b", {"image": False}),
-    ("facebook/chameleon-7b", {"image": True}),
+    ("facebook/chameleon-7b", {"image": False}),
     ("adept/fuyu-8b", {"image": False}),
     ("llava-hf/llava-1.5-7b-hf", {"image": True}),
     ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),

vllm/model_executor/models/blip2.py

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -404,6 +405,9 @@ def get_max_blip2_image_tokens(ctx: InputContext):

 class Blip2MultiModalProcessor(BaseMultiModalProcessor):

+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> Blip2Processor:
         return self.ctx.get_hf_processor(Blip2Processor)
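The same three-line override is added to Chameleon and Fuyu below. It relies on `BaseMultiModalProcessor` obtaining its parser through the `_get_data_parser` hook; the following standalone sketch (simplified stand-ins, not the actual vLLM classes) illustrates that pattern and how a single-image subclass tightens the limit:

from typing import Dict, List, Mapping


class DataParser:
    # Simplified stand-in for vllm.multimodal.parse.MultiModalDataParser.
    def __init__(self, *, max_mm_counts: Mapping[str, int] = {}) -> None:
        self.max_mm_counts = max_mm_counts

    def parse(self, mm_data: Mapping[str, List[str]]) -> Dict[str, List[str]]:
        for modality, items in mm_data.items():
            max_count = self.max_mm_counts.get(modality)
            if max_count is not None and len(items) > max_count:
                raise ValueError(f"At most {max_count} {modality} item(s) "
                                 f"are supported, got {len(items)}")
        return dict(mm_data)


class BaseProcessor:
    # Simplified stand-in for BaseMultiModalProcessor: subclasses customize
    # parsing by overriding the _get_data_parser hook.
    def _get_data_parser(self) -> DataParser:
        return DataParser()  # default: no per-modality cap

    def apply(self, mm_data: Mapping[str, List[str]]) -> Dict[str, List[str]]:
        return self._get_data_parser().parse(mm_data)


class SingleImageProcessor(BaseProcessor):
    # Mirrors the Blip2/Chameleon/Fuyu override added in this commit.
    def _get_data_parser(self) -> DataParser:
        return DataParser(max_mm_counts={"image": 1})


SingleImageProcessor().apply({"image": ["img0"]})            # accepted
# SingleImageProcessor().apply({"image": ["img0", "img1"]})  # ValueError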

vllm/model_executor/models/chameleon.py

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -60,6 +61,9 @@ def get_max_chameleon_image_tokens(ctx: InputContext):

 class ChameleonMultiModalProcessor(BaseMultiModalProcessor):

+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> ChameleonProcessor:
         return self.ctx.get_hf_processor(ChameleonProcessor)

vllm/model_executor/models/fuyu.py

Lines changed: 11 additions & 7 deletions
@@ -34,7 +34,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -54,7 +54,7 @@

 class FuyuImagePatchInputs(TypedDict):
     type: Literal["image_patches"]
-    data: torch.Tensor
+    flat_data: torch.Tensor
     """
     Shape:
     `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
@@ -63,7 +63,7 @@ class FuyuImagePatchInputs(TypedDict):
     patches_per_image: List[int]
     """
     List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `data`.
+    This is used to restore the first two dimensions of `flat_data`.
     """

@@ -102,6 +102,9 @@ def get_max_fuyu_image_tokens(ctx: InputContext):

 class FuyuMultiModalProcessor(BaseMultiModalProcessor):

+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> FuyuProcessor:
         return self.ctx.get_hf_processor(FuyuProcessor)

@@ -304,7 +307,7 @@ def _parse_and_validate_image_input(

         return FuyuImagePatchInputs(
             type="image_patches",
-            data=self._validate_pixel_values(
+            flat_data=self._validate_pixel_values(
                 flatten_bn(image_patches_flat, concat=True)),
             patches_per_image=[x.size(0) for x in image_patches_flat],
         )
@@ -313,12 +316,13 @@ def _parse_and_validate_image_input(

     def _process_image_input(
             self, image_input: FuyuImagePatchInputs) -> NestedTensors:
-        image_patches = image_input["data"]
+        image_patches_flat = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"]

         assert self.vision_embed_tokens is not None
-        vision_embeddings, _ = self.vision_embed_tokens(image_patches)
-        return vision_embeddings.split(patches_per_image, dim=0)
+        vision_embeddings_flat, _ = self.vision_embed_tokens(
+            image_patches_flat)
+        return vision_embeddings_flat.split(patches_per_image, dim=0)

     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
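The rename makes explicit that the patches of all images are flattened into a single tensor for the vision embedding call and only split back per image afterwards, using `patches_per_image`. A small torch sketch of that round trip (illustrative shapes and a stand-in projection, not Fuyu's real `vision_embed_tokens`):

import torch

# Two images contribute 3 and 5 patches; each patch is flattened to 16 values
# (patch_size_x * patch_size_y * num_channels in the real model).
patches_per_image = [3, 5]
flat_data = torch.randn(sum(patches_per_image), 16)

# Stand-in for vision_embed_tokens: embed every patch of every image at once.
vision_embed_tokens = torch.nn.Linear(16, 32)
vision_embeddings_flat = vision_embed_tokens(flat_data)

# split() restores the per-image grouping that the flattening discarded.
per_image = vision_embeddings_flat.split(patches_per_image, dim=0)
assert [tuple(t.shape) for t in per_image] == [(3, 32), (5, 32)]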

vllm/multimodal/parse.py

Lines changed: 26 additions & 2 deletions
@@ -220,11 +220,24 @@ def get_items(
 class MultiModalDataParser:
     """
     Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
+
+    Args:
+        max_mm_counts (Mapping[str, int]): The maximum allowed number of items
+            belonging to each modality. This effectively sets a hard limit over
+            `--limit-mm-per-prompt`.
+        target_sr (float, optional): Enables automatic resampling of audio
+            items to the model's expected sampling rate.
     """

-    def __init__(self, *, target_sr: Optional[float] = None) -> None:
+    def __init__(
+        self,
+        *,
+        max_mm_counts: Mapping[str, int] = {},
+        target_sr: Optional[float] = None,
+    ) -> None:
         super().__init__()

+        self.max_mm_counts = max_mm_counts
         self.target_sr = target_sr

     def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
@@ -332,13 +345,24 @@ def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:

     def parse_mm_data(self,
                       mm_data: MultiModalDataDict) -> MultiModalDataItems:
+        max_mm_counts = self.max_mm_counts
         subparsers = self._get_subparsers()

         mm_items = MultiModalDataItems()
         for k, v in mm_data.items():
             if k not in subparsers:
                 raise ValueError(f"Unsupported modality: {k}")

-            mm_items[k] = subparsers[k](v)
+            modality_items = subparsers[k](v)
+
+            if k in max_mm_counts:
+                max_count = max_mm_counts[k]
+                if len(modality_items) > max_count:
+                    raise ValueError(
+                        f"This model supports at most {max_count} {k} items "
+                        f"per prompt, but {len(modality_items)} {k} items "
+                        "were given or set as its limit_mm_per_prompt.")
+
+            mm_items[k] = modality_items

         return mm_items
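Taken together, a processor that returns `MultiModalDataParser(max_mm_counts={"image": 1})` now rejects multi-image requests at parse time, even if `--limit-mm-per-prompt` was raised. A hedged usage sketch (assuming an environment with vLLM at this commit plus Pillow; the image subparser accepts PIL images, as elsewhere in vllm/multimodal):

from PIL import Image

from vllm.multimodal.parse import MultiModalDataParser

parser = MultiModalDataParser(max_mm_counts={"image": 1})

# A single image still parses normally.
parser.parse_mm_data({"image": Image.new("RGB", (32, 32))})

# A second image exceeds the hard per-model limit and triggers the new check.
try:
    parser.parse_mm_data({
        "image": [Image.new("RGB", (32, 32)), Image.new("RGB", (32, 32))],
    })
except ValueError as exc:
    print(exc)  # "This model supports at most 1 image items per prompt, ..."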
