Skip to content

Commit b559fa6

Browse files
ywang96 and tjtanaa
authored and committed
[V1][Bugfix] Fix data item ordering in mixed-modality inference (vllm-project#12259)
Signed-off-by: Roger Wang <[email protected]>
1 parent 29b95c6 commit b559fa6

File tree

2 files changed

+62
-14
lines changed

2 files changed

+62
-14
lines changed

vllm/multimodal/utils.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import lru_cache
2+
from itertools import groupby
23
from pathlib import Path
34
from typing import TYPE_CHECKING, Optional, TypeVar, Union
45
from urllib.parse import ParseResult, urlparse
@@ -26,7 +27,7 @@
2627

2728
if TYPE_CHECKING:
2829
from .hasher import MultiModalHashDict
29-
from .inputs import MultiModalPlaceholderDict
30+
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
3031

3132

3233
class MediaConnector:
@@ -477,3 +478,34 @@ def merge_and_sort_multimodal_metadata(
477478
merged_hashes = None
478479

479480
return sorted_modalities, merged_placeholders, merged_hashes
481+
482+
483+
def group_mm_inputs_by_modality(
    mm_inputs: list["MultiModalKwargs"]) -> list[list["MultiModalKwargs"]]:
    """Group consecutive MultiModalKwargs from mm_inputs with the same modality
    together into the same list for batching purpose. For MultiModalKwargs with
    multiple modalities, put them into their own list.

    Args:
        mm_inputs: List of MultiModalKwargs.

    Returns:
        list[list[MultiModalKwargs]]: List of lists of MultiModalKwargs, where
        each inner list contains consecutive MultiModalKwargs sharing a single
        modality, or exactly one MultiModalKwargs that has multiple modalities.
    """
    if not mm_inputs:
        return []

    def modality_group_func(mm_input: "MultiModalKwargs") -> Union[str, int]:
        # If the input has multiple modalities, return its object id as the
        # grouping key so that it is guaranteed to land in a group of its own.
        if len(mm_input.modalities) > 1:
            return id(mm_input)

        # Otherwise return the single modality string; next(iter(...)) picks
        # the only element without materializing the collection into a list.
        return next(iter(mm_input.modalities))

    return [
        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
    ]

vllm/v1/worker/gpu_model_runner.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
1818
from vllm.model_executor.model_loader import get_model
1919
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
20+
from vllm.multimodal.utils import group_mm_inputs_by_modality
2021
from vllm.sampling_params import SamplingType
2122
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
2223
LayerBlockType, cdiv, is_pin_memory_available)
@@ -629,19 +630,34 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
629630
for input_id in encoder_input_ids:
630631
mm_inputs.append(req_state.mm_inputs[input_id])
631632
req_input_ids.append((req_id, input_id))
632-
batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
633-
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
634-
device=self.device)
635-
636-
# Run the encoder.
637-
# `encoder_outputs` is either of the following:
638-
# 1. A tensor of shape [num_images, feature_size, hidden_size]
639-
# in case when feature_size is fixed across all images.
640-
# 2. A list (length: num_images) of tensors, each of shape
641-
# [feature_size, hidden_size] in case when the feature size is
642-
# dynamic depending on input images.
643-
encoder_outputs = self.model.get_multimodal_embeddings(
644-
**batched_mm_inputs)
633+
634+
# Batch mm inputs as much as we can: if a request in the batch has
635+
# multiple modalities or a different modality than the previous one,
636+
# we process it separately to preserve item order.
637+
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
638+
# in the same batch while still being able to benefit from batching
639+
# multimodal inputs. The proper solution should be reordering the
640+
# encoder outputs.
641+
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
642+
643+
encoder_outputs = []
644+
for grouped_mm_inputs in grouped_mm_inputs_list:
645+
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
646+
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
647+
device=self.device)
648+
649+
# Run the encoder.
650+
# `curr_group_outputs` is either of the following:
651+
# 1. A tensor of shape (num_items, feature_size, hidden_size)
652+
# in case feature_size is fixed across all multimodal items.
653+
# 2. A list or tuple (length: num_items) of tensors, each of shape
654+
# (feature_size, hidden_size) in case the feature size is dynamic
655+
# depending on the input multimodal items.
656+
curr_group_outputs = self.model.get_multimodal_embeddings(
657+
**batched_mm_inputs)
658+
659+
for output in curr_group_outputs:
660+
encoder_outputs.append(output)
645661

646662
# Cache the encoder outputs.
647663
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):

0 commit comments

Comments
 (0)