Skip to content

Commit 23c1b10

Browse files
[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input (#11674)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent a115ac4 commit 23c1b10

File tree

3 files changed

+151
-168
lines changed

3 files changed

+151
-168
lines changed

vllm/multimodal/inputs.py

Lines changed: 116 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from collections import UserDict, defaultdict
33
from collections.abc import Mapping, Sequence
44
from dataclasses import dataclass
5-
from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
5+
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
6+
final)
67

78
import numpy as np
89
import torch
@@ -11,7 +12,7 @@
1112
from transformers import BatchFeature
1213
from typing_extensions import NotRequired, TypeAlias
1314

14-
from vllm.utils import JSONTree, is_list_of, json_map_leaves
15+
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
1516

1617
_T = TypeVar("_T")
1718

@@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
160161

161162

162163
@dataclass(frozen=True)
163-
class MultiModalFieldItem:
164-
"""
165-
Contains metadata and data in :class:`MultiModalKwargs`
166-
corresponding to a data item in :class:`MultiModalDataItems`.
167-
"""
164+
class MultiModalFieldElem:
165+
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
168166
field: "BaseMultiModalField"
169167
data: NestedTensors
170168

@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
186184
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
187185
raise NotImplementedError
188186

189-
def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
190-
return MultiModalFieldItem(self, data)
187+
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
188+
return MultiModalFieldElem(self, data)
191189

192-
def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
193-
"""Merge multiple instances of :class:`MultiModalFieldItem` together."""
190+
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
191+
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
194192
fields = [item.field for item in batch]
195193
if len(set(fields)) > 1:
196194
raise ValueError(f"Cannot merge different {fields=}")
197195

198196
data = self._reduce_data([item.data for item in batch])
199197

200-
return self._build_item(data)
198+
return self._build_elem(data)
201199

202200

203201
@dataclass(frozen=True)
204202
class MultiModalBatchedField(BaseMultiModalField):
205203
"""
206-
A :class:`BaseMultiModalField` implementation where an item is obtained by
207-
directly indexing into the first dimension of the underlying data.
204+
A :class:`BaseMultiModalField` implementation where an element in the batch
205+
is obtained by indexing into the first dimension of the underlying data.
208206
"""
209207

210-
def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
211-
return [self._build_item(item) for item in batch]
208+
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
209+
return [self._build_elem(item) for item in batch]
212210

213211
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
214212
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
215213
first_shape = batch[0].shape
216-
if all(item.shape == first_shape for item in batch):
214+
if all(elem.shape == first_shape for elem in batch):
217215
return torch.stack(batch)
218216

219217
return batch
@@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
222220
@dataclass(frozen=True)
223221
class MultiModalFlatField(BaseMultiModalField):
224222
"""
225-
A :class:`BaseMultiModalField` implementation where an item is obtained by
226-
slicing along the first dimension of the underlying data.
223+
A :class:`BaseMultiModalField` implementation where an element in the batch
224+
is obtained by slicing along the first dimension of the underlying data.
227225
"""
228226

229-
def build_items(
227+
def build_elems(
230228
self,
231229
batch: NestedTensors,
232230
slices: Sequence[slice],
233-
) -> list[MultiModalFieldItem]:
234-
return [self._build_item(batch[slice_]) for slice_ in slices]
231+
) -> list[MultiModalFieldElem]:
232+
return [self._build_elem(batch[slice_]) for slice_ in slices]
235233

236234
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
237235
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
238236
first_shape = batch[0].shape
239-
if all(item.shape[1:] == first_shape[1:] for item in batch):
237+
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
240238
return torch.concat(batch)
241239

242-
return [elem for item in batch for elem in item]
240+
return [e for elem in batch for e in elem]
243241

244242

245243
class MultiModalFieldConfig:
@@ -267,115 +265,111 @@ def __init__(
267265
) -> None:
268266
super().__init__()
269267

270-
self._field_cls = field_cls
271-
self._modality = modality
272-
self._field_config = field_config
268+
self.field_cls = field_cls
269+
self.modality = modality
270+
self.field_config = field_config
273271

274-
def build_items(
272+
def build_elems(
275273
self,
276274
key: str,
277275
batch: NestedTensors,
278-
) -> list[MultiModalFieldItem]:
279-
field = self._field_cls(key=key, modality=self._modality)
280-
return field.build_items(batch, **self._field_config) # type: ignore
276+
) -> Sequence[MultiModalFieldElem]:
277+
field = self.field_cls(key=key, modality=self.modality)
278+
return field.build_elems(batch, **self.field_config) # type: ignore
281279

282280

283-
class MultiModalKwargs(UserDict[str, NestedTensors]):
281+
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
282+
"""
283+
A collection of :class:`MultiModalFieldElem`
284+
corresponding to a data item in :class:`MultiModalDataItems`.
284285
"""
285-
A dictionary that represents the keyword arguments to
286-
:meth:`~torch.nn.Module.forward`.
287286

288-
The metadata :code:`items_by_key` defines how to split batched keyword
289-
arguments corresponding to each data item in :class:`MultiModalDataItems`:
287+
@staticmethod
288+
def from_elems(elems: Sequence[MultiModalFieldElem]):
289+
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
290290

291-
- For a keyword argument, we can access the :code:`i` th item in the batch
292-
via :code:`items_by_key[key][i]`.
293-
- We can gather the keyword arguments belonging to a modality by finding
294-
the keys with items that belong to that modality, then accessing
295-
the :code:`i` th item in the batch for each such key.
291+
@property
292+
def modality(self) -> str:
293+
modalities = {elem.field.modality for elem in self.data.values()}
294+
assert len(modalities) == 1, f"Found different modalities={modalities}"
295+
return next(iter(modalities))
296296

297-
Example:
298297

299-
.. code-block:: python
300-
301-
# All items belong to the "image" modality
302-
items_by_key={
303-
"pixel_values": [a, b, c, d], # "image" modality
304-
"image_grid_thw": [e, f, g, h], # "image" modality
305-
"pixel_values_video": [h, i, j], # "video" modality
306-
"video_grid_thw": [k, l, m], # "video" modality
307-
}
298+
# NOTE: UserDict is for V0 compatibility.
299+
# V1 should access individual items via `get_item`.
300+
class MultiModalKwargs(UserDict[str, NestedTensors]):
301+
"""
302+
A dictionary that represents the keyword arguments to
303+
:meth:`~torch.nn.Module.forward`.
308304
309-
- The keyword arguments belonging to the first image are
310-
:code:`{"pixel_values": a, "image_grid_thw": e}`.
311-
- The keyword arguments belonging to the second video are
312-
:code:`{"pixel_values_video": i, "video_grid_thw": l}`.
305+
The metadata :code:`items` enables us to obtain the keyword arguments
306+
corresponding to each data item in :class:`MultiModalDataItems`, via
307+
:meth:`get_item` and :meth:`get_items`.
313308
"""
314309

315310
@staticmethod
316311
def from_hf_inputs(
317312
hf_inputs: BatchFeature,
318313
config_by_key: Mapping[str, MultiModalFieldConfig],
319-
*,
320-
enable_sanity_checks: bool = False,
321314
):
322315
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
323316
# We assume that those fields are not used in vLLM
324-
items_by_key = {
325-
key: config.build_items(key, batch)
326-
for key, config in config_by_key.items()
327-
if (batch := hf_inputs.get(key)) is not None
328-
}
329-
330-
return MultiModalKwargs.from_items_by_key(
331-
items_by_key,
332-
enable_sanity_checks=enable_sanity_checks,
333-
)
317+
elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
318+
keys_by_modality = defaultdict[str, set[str]](set)
319+
for key, config in config_by_key.items():
320+
batch = hf_inputs.get(key)
321+
if batch is not None:
322+
elems = config.build_elems(key, batch)
323+
if len(elems) > 0:
324+
elems_by_key[key] = elems
325+
keys_by_modality[config.modality].add(key)
326+
327+
items = list[MultiModalKwargsItem]()
328+
for modality, keys in keys_by_modality.items():
329+
elems_in_modality = {k: elems_by_key[k] for k in keys}
330+
batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
331+
332+
if len(set(batch_sizes.values())) > 1:
333+
raise ValueError(
334+
f"Cannot merge different batch sizes for {modality=}! "
335+
f"Found: {batch_sizes=}")
336+
337+
batch_size = next(iter(batch_sizes.values()))
338+
for item_idx in range(batch_size):
339+
elems = [v[item_idx] for v in elems_in_modality.values()]
340+
items.append(MultiModalKwargsItem.from_elems(elems))
341+
342+
return MultiModalKwargs.from_items(items)
334343

335344
@staticmethod
336-
def from_items_by_key(
337-
items_by_key: Mapping[str, list[MultiModalFieldItem]],
338-
*,
339-
enable_sanity_checks: bool = False,
340-
) -> "MultiModalKwargs":
345+
def from_items(items: Sequence[MultiModalKwargsItem]):
346+
"""Construct a new :class:`MultiModalKwargs` from multiple items."""
347+
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
348+
for item in items:
349+
for key, elem in item.items():
350+
elems_by_key[key].append(elem)
351+
341352
data = {
342-
key: items[0].field.reduce(items).data
343-
for key, items in items_by_key.items() if len(items) > 0
353+
key: elems[0].field.reduce(elems).data
354+
for key, elems in elems_by_key.items() if len(elems) > 0
344355
}
345356

346-
return MultiModalKwargs(data,
347-
items_by_key=items_by_key,
348-
enable_sanity_checks=enable_sanity_checks)
357+
return MultiModalKwargs(data, items=items)
349358

350359
def __init__(
351360
self,
352361
data: Mapping[str, NestedTensors],
353362
*,
354-
items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
355-
enable_sanity_checks: bool = False,
363+
items: Optional[Sequence[MultiModalKwargsItem]] = None,
356364
) -> None:
357365
super().__init__(data)
358366

359-
# Shallow copy to avoid footgun in case a defaultdict is passed in
360-
self._items_by_key = dict(items_by_key)
367+
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
368+
self._items_by_modality = dict(items_by_modality)
361369

362-
keys_by_modality = defaultdict[str, set[str]](set)
363-
for key, items in items_by_key.items():
364-
for item in items:
365-
keys_by_modality[item.field.modality].add(key)
366-
367-
self._keys_by_modality = dict(keys_by_modality)
368-
369-
if enable_sanity_checks:
370-
for modality, keys in keys_by_modality.items():
371-
items_in_modality = {k: items_by_key[k] for k in keys}
372-
batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
373-
batch_size = next(iter(batch_sizes.values()), 0)
374-
assert all(bs == batch_size
375-
for bs in batch_sizes.values()), dict(
376-
modality=modality,
377-
batch_sizes=batch_sizes,
378-
items_by_key=items_by_key)
370+
@property
371+
def modalities(self):
372+
return self._items_by_modality.keys()
379373

380374
@staticmethod
381375
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
@@ -452,58 +446,44 @@ def as_kwargs(
452446
def __eq__(self, other: object) -> bool:
453447
if not isinstance(other, self.__class__):
454448
return False
455-
if self._items_by_key != other._items_by_key:
449+
if self._items_by_modality != other._items_by_modality:
456450
return False
457451

458452
ks = self.keys()
459453
return (ks == other.keys()
460454
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
461455

462-
def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
463-
return self._items_by_key[key][item_index]
456+
def _validate_modality(self, method_name: str, modality: str) -> None:
457+
if not self._items_by_modality:
458+
raise RuntimeError(
459+
f"`{method_name}` is not supported when "
460+
"MultiModalKwargs is not initialized with `items`")
464461

465-
def get_items_by_modality(
466-
self,
467-
modality: str,
468-
item_index: int,
469-
) -> Mapping[str, MultiModalFieldItem]:
470-
"""
471-
Get the keyword arguments corresponding to an item identified by
472-
its modality and index.
473-
"""
474-
if modality not in self._keys_by_modality:
475-
available_modalities = set(self._keys_by_modality.keys())
462+
if modality not in self._items_by_modality:
463+
available_modalities = set(self._items_by_modality.keys())
476464
raise KeyError(f"Modality {modality!r} not found. "
477465
f"Available modalities: {available_modalities}")
478466

479-
keys_to_gather = self._keys_by_modality[modality]
467+
def get_item_count(self, modality: str) -> int:
468+
"""Get the number of items belonging to a modality."""
469+
self._validate_modality("get_item_count", modality)
470+
return len(self._items_by_modality[modality])
480471

481-
return {
482-
key: self.get_item(key, item_index)
483-
for key in keys_to_gather if key in self
484-
}
472+
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
473+
"""
474+
Get the keyword arguments corresponding to an item identified by
475+
its modality and index.
476+
"""
477+
self._validate_modality("get_item", modality)
478+
return self._items_by_modality[modality][item_index]
485479

486-
@staticmethod
487-
def from_items_by_modality(
488-
items_by_modality: Mapping[str, list[Mapping[str,
489-
MultiModalFieldItem]]],
490-
*,
491-
enable_sanity_checks: bool = False,
492-
) -> "MultiModalKwargs":
480+
def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
493481
"""
494-
Construct a new :class:`MultiModalKwargs` from multiple items returned
495-
by :meth:`get_fields_by_modality`.
482+
Get the keyword arguments corresponding to each item belonging to
483+
a modality.
496484
"""
497-
items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
498-
for fields in items_by_modality.values():
499-
for field in fields:
500-
for k, v in field.items():
501-
items_by_key[k].append(v)
502-
503-
return MultiModalKwargs.from_items_by_key(
504-
items_by_key,
505-
enable_sanity_checks=enable_sanity_checks,
506-
)
485+
self._validate_modality("get_items", modality)
486+
return self._items_by_modality[modality]
507487

508488

509489
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]

0 commit comments

Comments
 (0)