2
2
from collections import UserDict , defaultdict
3
3
from collections .abc import Mapping , Sequence
4
4
from dataclasses import dataclass
5
- from typing import Any , Literal , TypedDict , TypeVar , Union , cast , final
5
+ from typing import (Any , Literal , Optional , TypedDict , TypeVar , Union , cast ,
6
+ final )
6
7
7
8
import numpy as np
8
9
import torch
11
12
from transformers import BatchFeature
12
13
from typing_extensions import NotRequired , TypeAlias
13
14
14
- from vllm .utils import JSONTree , is_list_of , json_map_leaves
15
+ from vllm .utils import JSONTree , full_groupby , is_list_of , json_map_leaves
15
16
16
17
_T = TypeVar ("_T" )
17
18
@@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
160
161
161
162
162
163
@dataclass (frozen = True )
163
- class MultiModalFieldItem :
164
- """
165
- Contains metadata and data in :class:`MultiModalKwargs`
166
- corresponding to a data item in :class:`MultiModalDataItems`.
167
- """
164
+ class MultiModalFieldElem :
165
+ """Contains metadata and data of an item in :class:`MultiModalKwargs`."""
168
166
field : "BaseMultiModalField"
169
167
data : NestedTensors
170
168
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
186
184
def _reduce_data (self , batch : list [NestedTensors ]) -> NestedTensors :
187
185
raise NotImplementedError
188
186
189
- def _build_item (self , data : NestedTensors ) -> MultiModalFieldItem :
190
- return MultiModalFieldItem (self , data )
187
+ def _build_elem (self , data : NestedTensors ) -> MultiModalFieldElem :
188
+ return MultiModalFieldElem (self , data )
191
189
192
- def reduce (self , batch : list [MultiModalFieldItem ]) -> MultiModalFieldItem :
193
- """Merge multiple instances of :class:`MultiModalFieldItem ` together."""
190
+ def reduce (self , batch : list [MultiModalFieldElem ]) -> MultiModalFieldElem :
191
+ """Merge multiple instances of :class:`MultiModalFieldElem ` together."""
194
192
fields = [item .field for item in batch ]
195
193
if len (set (fields )) > 1 :
196
194
raise ValueError (f"Cannot merge different { fields = } " )
197
195
198
196
data = self ._reduce_data ([item .data for item in batch ])
199
197
200
- return self ._build_item (data )
198
+ return self ._build_elem (data )
201
199
202
200
203
201
@dataclass (frozen = True )
204
202
class MultiModalBatchedField (BaseMultiModalField ):
205
203
"""
206
- A :class:`BaseMultiModalField` implementation where an item is obtained by
207
- directly indexing into the first dimension of the underlying data.
204
+ A :class:`BaseMultiModalField` implementation where an element in the batch
205
+ is obtained by indexing into the first dimension of the underlying data.
208
206
"""
209
207
210
- def build_items (self , batch : NestedTensors ) -> list [MultiModalFieldItem ]:
211
- return [self ._build_item (item ) for item in batch ]
208
+ def build_elems (self , batch : NestedTensors ) -> list [MultiModalFieldElem ]:
209
+ return [self ._build_elem (item ) for item in batch ]
212
210
213
211
def _reduce_data (self , batch : list [NestedTensors ]) -> NestedTensors :
214
212
if len (batch ) > 0 and is_list_of (batch , torch .Tensor , check = "all" ):
215
213
first_shape = batch [0 ].shape
216
- if all (item .shape == first_shape for item in batch ):
214
+ if all (elem .shape == first_shape for elem in batch ):
217
215
return torch .stack (batch )
218
216
219
217
return batch
@@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
222
220
@dataclass (frozen = True )
223
221
class MultiModalFlatField (BaseMultiModalField ):
224
222
"""
225
- A :class:`BaseMultiModalField` implementation where an item is obtained by
226
- slicing along the first dimension of the underlying data.
223
+ A :class:`BaseMultiModalField` implementation where an element in the batch
224
+ is obtained by slicing along the first dimension of the underlying data.
227
225
"""
228
226
229
- def build_items (
227
+ def build_elems (
230
228
self ,
231
229
batch : NestedTensors ,
232
230
slices : Sequence [slice ],
233
- ) -> list [MultiModalFieldItem ]:
234
- return [self ._build_item (batch [slice_ ]) for slice_ in slices ]
231
+ ) -> list [MultiModalFieldElem ]:
232
+ return [self ._build_elem (batch [slice_ ]) for slice_ in slices ]
235
233
236
234
def _reduce_data (self , batch : list [NestedTensors ]) -> NestedTensors :
237
235
if len (batch ) > 0 and is_list_of (batch , torch .Tensor , check = "all" ):
238
236
first_shape = batch [0 ].shape
239
- if all (item .shape [1 :] == first_shape [1 :] for item in batch ):
237
+ if all (elem .shape [1 :] == first_shape [1 :] for elem in batch ):
240
238
return torch .concat (batch )
241
239
242
- return [elem for item in batch for elem in item ]
240
+ return [e for elem in batch for e in elem ]
243
241
244
242
245
243
class MultiModalFieldConfig :
@@ -267,115 +265,111 @@ def __init__(
267
265
) -> None :
268
266
super ().__init__ ()
269
267
270
- self ._field_cls = field_cls
271
- self ._modality = modality
272
- self ._field_config = field_config
268
+ self .field_cls = field_cls
269
+ self .modality = modality
270
+ self .field_config = field_config
273
271
274
- def build_items (
272
+ def build_elems (
275
273
self ,
276
274
key : str ,
277
275
batch : NestedTensors ,
278
- ) -> list [ MultiModalFieldItem ]:
279
- field = self ._field_cls (key = key , modality = self ._modality )
280
- return field .build_items (batch , ** self ._field_config ) # type: ignore
276
+ ) -> Sequence [ MultiModalFieldElem ]:
277
+ field = self .field_cls (key = key , modality = self .modality )
278
+ return field .build_elems (batch , ** self .field_config ) # type: ignore
281
279
282
280
283
- class MultiModalKwargs (UserDict [str , NestedTensors ]):
281
+ class MultiModalKwargsItem (UserDict [str , MultiModalFieldElem ]):
282
+ """
283
+ A collection of :class:`MultiModalFieldElem`
284
+ corresponding to a data item in :class:`MultiModalDataItems`.
284
285
"""
285
- A dictionary that represents the keyword arguments to
286
- :meth:`~torch.nn.Module.forward`.
287
286
288
- The metadata :code:`items_by_key` defines how to split batched keyword
289
- arguments corresponding to each data item in :class:`MultiModalDataItems`:
287
+ @staticmethod
288
+ def from_elems (elems : Sequence [MultiModalFieldElem ]):
289
+ return MultiModalKwargsItem ({elem .field .key : elem for elem in elems })
290
290
291
- - For a keyword argument, we can access the :code:`i` th item in the batch
292
- via :code:`items_by_key[key][i]`.
293
- - We can gather the keyword arguments belonging to a modality by finding
294
- the keys with items that belong to that modality, then accessing
295
- the :code:`i` th item in the batch for each such key.
291
+ @ property
292
+ def modality ( self ) -> str :
293
+ modalities = { elem . field . modality for elem in self . data . values ()}
294
+ assert len ( modalities ) == 1 , f"Found different modalities= { modalities } "
295
+ return next ( iter ( modalities ))
296
296
297
- Example:
298
297
299
- .. code-block:: python
300
-
301
- # All items belong to the "image" modality
302
- items_by_key={
303
- "pixel_values": [a, b, c, d], # "image" modality
304
- "image_grid_thw": [e, f, g, h], # "image" modality
305
- "pixel_values_video": [h, i, j], # "video" modality
306
- "video_grid_thw": [k, l, m], # "video" modality
307
- }
298
+ # NOTE: UserDict is for V0 compatibility.
299
+ # V1 should access individual items via `get_item`.
300
+ class MultiModalKwargs (UserDict [str , NestedTensors ]):
301
+ """
302
+ A dictionary that represents the keyword arguments to
303
+ :meth:`~torch.nn.Module.forward`.
308
304
309
- - The keyword arguments belonging to the first image are
310
- :code:`{"pixel_values": a, "image_grid_thw": e}`.
311
- - The keyword arguments belonging to the second video are
312
- :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
305
+ The metadata :code:`items` enables us to obtain the keyword arguments
306
+ corresponding to each data item in :class:`MultiModalDataItems`, via
307
+ :meth:`get_item` and :meth:`get_items`.
313
308
"""
314
309
315
310
@staticmethod
316
311
def from_hf_inputs (
317
312
hf_inputs : BatchFeature ,
318
313
config_by_key : Mapping [str , MultiModalFieldConfig ],
319
- * ,
320
- enable_sanity_checks : bool = False ,
321
314
):
322
315
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
323
316
# We assume that those fields are not used in vLLM
324
- items_by_key = {
325
- key : config .build_items (key , batch )
326
- for key , config in config_by_key .items ()
327
- if (batch := hf_inputs .get (key )) is not None
328
- }
329
-
330
- return MultiModalKwargs .from_items_by_key (
331
- items_by_key ,
332
- enable_sanity_checks = enable_sanity_checks ,
333
- )
317
+ elems_by_key = dict [str , Sequence [MultiModalFieldElem ]]()
318
+ keys_by_modality = defaultdict [str , set [str ]](set )
319
+ for key , config in config_by_key .items ():
320
+ batch = hf_inputs .get (key )
321
+ if batch is not None :
322
+ elems = config .build_elems (key , batch )
323
+ if len (elems ) > 0 :
324
+ elems_by_key [key ] = elems
325
+ keys_by_modality [config .modality ].add (key )
326
+
327
+ items = list [MultiModalKwargsItem ]()
328
+ for modality , keys in keys_by_modality .items ():
329
+ elems_in_modality = {k : elems_by_key [k ] for k in keys }
330
+ batch_sizes = {k : len (v ) for k , v in elems_in_modality .items ()}
331
+
332
+ if len (set (batch_sizes .values ())) > 1 :
333
+ raise ValueError (
334
+ f"Cannot merge different batch sizes for { modality = } ! "
335
+ f"Found: { batch_sizes = } " )
336
+
337
+ batch_size = next (iter (batch_sizes .values ()))
338
+ for item_idx in range (batch_size ):
339
+ elems = [v [item_idx ] for v in elems_in_modality .values ()]
340
+ items .append (MultiModalKwargsItem .from_elems (elems ))
341
+
342
+ return MultiModalKwargs .from_items (items )
334
343
335
344
@staticmethod
336
- def from_items_by_key (
337
- items_by_key : Mapping [str , list [MultiModalFieldItem ]],
338
- * ,
339
- enable_sanity_checks : bool = False ,
340
- ) -> "MultiModalKwargs" :
345
+ def from_items (items : Sequence [MultiModalKwargsItem ]):
346
+ """Construct a new :class:`MultiModalKwargs` from multiple items."""
347
+ elems_by_key = defaultdict [str , list [MultiModalFieldElem ]](list )
348
+ for item in items :
349
+ for key , elem in item .items ():
350
+ elems_by_key [key ].append (elem )
351
+
341
352
data = {
342
- key : items [0 ].field .reduce (items ).data
343
- for key , items in items_by_key .items () if len (items ) > 0
353
+ key : elems [0 ].field .reduce (elems ).data
354
+ for key , elems in elems_by_key .items () if len (elems ) > 0
344
355
}
345
356
346
- return MultiModalKwargs (data ,
347
- items_by_key = items_by_key ,
348
- enable_sanity_checks = enable_sanity_checks )
357
+ return MultiModalKwargs (data , items = items )
349
358
350
359
def __init__ (
351
360
self ,
352
361
data : Mapping [str , NestedTensors ],
353
362
* ,
354
- items_by_key : Mapping [str , list [MultiModalFieldItem ]] = {},
355
- enable_sanity_checks : bool = False ,
363
+ items : Optional [Sequence [MultiModalKwargsItem ]] = None ,
356
364
) -> None :
357
365
super ().__init__ (data )
358
366
359
- # Shallow copy to avoid footgun in case a defaultdict is passed in
360
- self ._items_by_key = dict (items_by_key )
367
+ items_by_modality = full_groupby ( items or [], key = lambda x : x . modality )
368
+ self ._items_by_modality = dict (items_by_modality )
361
369
362
- keys_by_modality = defaultdict [str , set [str ]](set )
363
- for key , items in items_by_key .items ():
364
- for item in items :
365
- keys_by_modality [item .field .modality ].add (key )
366
-
367
- self ._keys_by_modality = dict (keys_by_modality )
368
-
369
- if enable_sanity_checks :
370
- for modality , keys in keys_by_modality .items ():
371
- items_in_modality = {k : items_by_key [k ] for k in keys }
372
- batch_sizes = {k : len (v ) for k , v in items_in_modality .items ()}
373
- batch_size = next (iter (batch_sizes .values ()), 0 )
374
- assert all (bs == batch_size
375
- for bs in batch_sizes .values ()), dict (
376
- modality = modality ,
377
- batch_sizes = batch_sizes ,
378
- items_by_key = items_by_key )
370
+ @property
371
+ def modalities (self ):
372
+ return self ._items_by_modality .keys ()
379
373
380
374
@staticmethod
381
375
def _try_stack (nested_tensors : NestedTensors ) -> NestedTensors :
@@ -452,58 +446,44 @@ def as_kwargs(
452
446
def __eq__ (self , other : object ) -> bool :
453
447
if not isinstance (other , self .__class__ ):
454
448
return False
455
- if self ._items_by_key != other ._items_by_key :
449
+ if self ._items_by_modality != other ._items_by_modality :
456
450
return False
457
451
458
452
ks = self .keys ()
459
453
return (ks == other .keys ()
460
454
and all (nested_tensors_equal (self [k ], other [k ]) for k in ks ))
461
455
462
- def get_item (self , key : str , item_index : int ) -> MultiModalFieldItem :
463
- return self ._items_by_key [key ][item_index ]
456
+ def _validate_modality (self , method_name : str , modality : str ) -> None :
457
+ if not self ._items_by_modality :
458
+ raise RuntimeError (
459
+ f"`{ method_name } ` is not supported when "
460
+ "MultiModalKwargs is not initialized with `items`" )
464
461
465
- def get_items_by_modality (
466
- self ,
467
- modality : str ,
468
- item_index : int ,
469
- ) -> Mapping [str , MultiModalFieldItem ]:
470
- """
471
- Get the keyword arguments corresponding to an item identified by
472
- its modality and index.
473
- """
474
- if modality not in self ._keys_by_modality :
475
- available_modalities = set (self ._keys_by_modality .keys ())
462
+ if modality not in self ._items_by_modality :
463
+ available_modalities = set (self ._items_by_modality .keys ())
476
464
raise KeyError (f"Modality { modality !r} not found. "
477
465
f"Available modalities: { available_modalities } " )
478
466
479
- keys_to_gather = self ._keys_by_modality [modality ]
467
+ def get_item_count (self , modality : str ) -> int :
468
+ """Get the number of items belonging to a modality."""
469
+ self ._validate_modality ("get_item_count" , modality )
470
+ return len (self ._items_by_modality [modality ])
480
471
481
- return {
482
- key : self .get_item (key , item_index )
483
- for key in keys_to_gather if key in self
484
- }
472
+ def get_item (self , modality : str , item_index : int ) -> MultiModalKwargsItem :
473
+ """
474
+ Get the keyword arguments corresponding to an item identified by
475
+ its modality and index.
476
+ """
477
+ self ._validate_modality ("get_item" , modality )
478
+ return self ._items_by_modality [modality ][item_index ]
485
479
486
- @staticmethod
487
- def from_items_by_modality (
488
- items_by_modality : Mapping [str , list [Mapping [str ,
489
- MultiModalFieldItem ]]],
490
- * ,
491
- enable_sanity_checks : bool = False ,
492
- ) -> "MultiModalKwargs" :
480
+ def get_items (self , modality : str ) -> Sequence [MultiModalKwargsItem ]:
493
481
"""
494
- Construct a new :class:`MultiModalKwargs` from multiple items returned
495
- by :meth:`get_fields_by_modality` .
482
+ Get the keyword arguments corresponding to each item belonging to
483
+ a modality .
496
484
"""
497
- items_by_key = defaultdict [str , list [MultiModalFieldItem ]](list )
498
- for fields in items_by_modality .values ():
499
- for field in fields :
500
- for k , v in field .items ():
501
- items_by_key [k ].append (v )
502
-
503
- return MultiModalKwargs .from_items_by_key (
504
- items_by_key ,
505
- enable_sanity_checks = enable_sanity_checks ,
506
- )
485
+ self ._validate_modality ("get_items" , modality )
486
+ return self ._items_by_modality [modality ]
507
487
508
488
509
489
MultiModalPlaceholderDict = Mapping [str , Sequence [PlaceholderRange ]]
0 commit comments