Skip to content

Commit b3cf368

Browse files
authored
[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)
1 parent c8525f0 commit b3cf368

22 files changed

+249
-150
lines changed

examples/offline_inference/vision_language.py

Lines changed: 176 additions & 118 deletions
Large diffs are not rendered by default.

vllm/model_executor/models/aria.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,9 @@ def _process_image_input(
602602

603603
return self.multi_modal_projector(image_outputs, image_attn_mask)
604604

605-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
605+
def get_multimodal_embeddings(
606+
self, **kwargs
607+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
606608
image_input = self._parse_and_validate_image_input(**kwargs)
607609
if image_input is None:
608610
return None

vllm/model_executor/models/blip2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,9 @@ def _process_image_input(self,
628628

629629
return self.language_projection(query_output)
630630

631-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
631+
def get_multimodal_embeddings(
632+
self, **kwargs
633+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
632634
image_input = self._parse_and_validate_image_input(**kwargs)
633635
if image_input is None:
634636
return None

vllm/model_executor/models/chameleon.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,9 @@ def _parse_and_validate_image_input(
986986
data=self._validate_pixel_values(pixel_values),
987987
)
988988

989-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
989+
def get_multimodal_embeddings(
990+
self, **kwargs
991+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
990992
image_input = self._parse_and_validate_image_input(**kwargs)
991993
if image_input is None:
992994
return None

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,9 @@ def _process_image_input(
606606
return self._pixel_values_to_embedding(
607607
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
608608

609-
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
609+
def get_multimodal_embeddings(
610+
self, **kwargs: object
611+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
610612
image_input = self._parse_and_validate_image_input(**kwargs)
611613
if image_input is None:
612614
return None

vllm/model_executor/models/florence2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,9 @@ def _process_image_input(
10371037
pixel_values = image_input["data"]
10381038
return self._encode_image(pixel_values)
10391039

1040-
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
1040+
def get_multimodal_embeddings(
1041+
self, **kwargs: object
1042+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
10411043
image_input = self._parse_and_validate_image_input(**kwargs)
10421044
if image_input is None:
10431045
return None

vllm/model_executor/models/fuyu.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
""" PyTorch Fuyu model."""
1919
import math
2020
from collections.abc import Iterable, Mapping, Sequence
21-
from typing import List, Literal, Optional, Set, Tuple, TypedDict
21+
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
2222

2323
import torch
2424
import torch.nn as nn
@@ -327,7 +327,9 @@ def _process_image_input(
327327
image_patches_flat)
328328
return vision_embeddings_flat.split(patches_per_image, dim=0)
329329

330-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
330+
def get_multimodal_embeddings(
331+
self, **kwargs
332+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
331333
image_input = self._parse_and_validate_image_input(**kwargs)
332334
if image_input is None:
333335
return None

vllm/model_executor/models/glm4v.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,9 @@ def _process_image_input(
595595

596596
return self.transformer.vision(pixel_values)
597597

598-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
598+
def get_multimodal_embeddings(
599+
self, **kwargs
600+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
599601
image_input = self._parse_and_validate_image_input(**kwargs)
600602
if image_input is None:
601603
return None

vllm/model_executor/models/idefics3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
617617
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
618618
self.sampler = get_sampler()
619619

620-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
620+
def get_multimodal_embeddings(
621+
self, **kwargs
622+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
621623
image_input = self.model._parse_and_validate_image_input(**kwargs)
622624
if image_input is None:
623625
return None

vllm/model_executor/models/interfaces.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Protocol, Type, Union, overload, runtime_checkable)
55

66
import torch
7+
from torch import Tensor
78
from typing_extensions import TypeIs, TypeVar
89

910
from vllm.logger import init_logger
@@ -15,12 +16,11 @@
1516

1617
if TYPE_CHECKING:
1718
from vllm.attention import AttentionMetadata
18-
from vllm.multimodal.inputs import NestedTensors # noqa: F401
1919
from vllm.sequence import IntermediateTensors
2020

2121
logger = init_logger(__name__)
2222

23-
T = TypeVar("T", default="NestedTensors")
23+
T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])
2424

2525

2626
@runtime_checkable
@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
3636
MRO of your model class.
3737
"""
3838

39-
def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
39+
def get_multimodal_embeddings(self, **kwargs) -> T:
4040
"""
4141
Returns multimodal embeddings generated from multimodal kwargs
4242
to be merged with text embeddings.
@@ -59,18 +59,18 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
5959
@overload
6060
def get_input_embeddings(
6161
self,
62-
input_ids: torch.Tensor,
62+
input_ids: Tensor,
6363
multimodal_embeddings: Optional[T] = None,
6464
attn_metadata: Optional["AttentionMetadata"] = None,
65-
) -> torch.Tensor:
65+
) -> Tensor:
6666
...
6767

6868
@overload
6969
def get_input_embeddings(
7070
self,
71-
input_ids: torch.Tensor,
71+
input_ids: Tensor,
7272
multimodal_embeddings: Optional[T] = None,
73-
) -> torch.Tensor:
73+
) -> Tensor:
7474
"""
7575
Returns the input embeddings merged from the text embeddings from
7676
input_ids and the multimodal embeddings generated from multimodal
@@ -210,7 +210,7 @@ def forward(
210210
self,
211211
*,
212212
intermediate_tensors: Optional["IntermediateTensors"],
213-
) -> Union[torch.Tensor, "IntermediateTensors"]:
213+
) -> Union[Tensor, "IntermediateTensors"]:
214214
"""
215215
Accept :class:`IntermediateTensors` when PP rank > 0.
216216
@@ -237,7 +237,7 @@ def forward(
237237
self,
238238
*,
239239
intermediate_tensors: Optional["IntermediateTensors"],
240-
) -> Union[torch.Tensor, "IntermediateTensors"]:
240+
) -> Union[Tensor, "IntermediateTensors"]:
241241
...
242242

243243

vllm/model_executor/models/internvl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,9 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
904904
else:
905905
self.visual_token_mask = None
906906

907-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
907+
def get_multimodal_embeddings(
908+
self, **kwargs
909+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
908910
image_input = self._parse_and_validate_image_input(**kwargs)
909911
if image_input is None:
910912
return None

vllm/model_executor/models/llava.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,9 @@ def _process_image_input(self,
635635
image_features = self._process_image_pixels(image_input)
636636
return self.multi_modal_projector(image_features)
637637

638-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
638+
def get_multimodal_embeddings(
639+
self, **kwargs
640+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
639641
image_input = self._parse_and_validate_image_input(**kwargs)
640642
if image_input is None:
641643
return None

vllm/model_executor/models/llava_next.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ def _process_image_input(
479479
for i, patch_features_batch in enumerate(patch_embeddings)
480480
]
481481

482-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
482+
def get_multimodal_embeddings(
483+
self, **kwargs
484+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
483485
image_input = self._parse_and_validate_image_input(**kwargs)
484486
if image_input is None:
485487
return None

vllm/model_executor/models/llava_next_video.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
420420
raise ValueError(
421421
f"Unsupported type of video input {type(video_pixels)}")
422422

423-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
423+
def get_multimodal_embeddings(
424+
self, **kwargs
425+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
424426
video_input = self._parse_and_validate_video_input(**kwargs)
425427
if video_input is None:
426428
return None

vllm/model_executor/models/molmo.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
PromptInsertion, PromptUpdate)
5151
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
5252
from vllm.sequence import IntermediateTensors
53-
from vllm.utils import JSONTree, json_map_leaves
53+
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
5454

5555
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
5656
SupportsQuant)
@@ -1576,21 +1576,24 @@ def _get_mm_embeds(
15761576

15771577
return embeds_in_batch
15781578

1579-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
1579+
def get_multimodal_embeddings(
1580+
self, **kwargs
1581+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
15801582
image_input = self._parse_and_validate_image_input(**kwargs)
15811583
if image_input is None:
15821584
return None
15831585

15841586
image_features = self._process_image_input(image_input)
15851587

1586-
return [
1588+
nested_embeds = [
15871589
self._get_mm_embeds(*args) for args in zip(
15881590
image_features,
15891591
image_input["feat_is_patch"],
15901592
image_input["num_crops"],
15911593
image_input["embed_is_patch"],
15921594
)
15931595
]
1596+
return flatten_2d_lists(nested_embeds)
15941597

15951598
def get_input_embeddings(
15961599
self,

vllm/model_executor/models/paligemma.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,9 @@ def _process_image_input(
263263

264264
return self.multi_modal_projector(image_features)
265265

266-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
266+
def get_multimodal_embeddings(
267+
self, **kwargs
268+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
267269
image_input = self._parse_and_validate_image_input(**kwargs)
268270
if image_input is None:
269271
return None

vllm/model_executor/models/phi3v.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,9 @@ def _process_image_input(
648648

649649
return image_embeds
650650

651-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
651+
def get_multimodal_embeddings(
652+
self, **kwargs
653+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
652654
image_input = self._parse_and_validate_image_input(**kwargs)
653655
if image_input is None:
654656
return None

vllm/model_executor/models/pixtral.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,9 @@ def sampler(self):
220220

221221
return get_sampler()
222222

223-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
223+
def get_multimodal_embeddings(
224+
self, **kwargs
225+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
224226
image_input, image_tokens = self._parse_and_validate_image_input(
225227
**kwargs)
226228
if image_input is None:

vllm/model_executor/models/qwen2_audio.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ def _process_audio_input(self,
356356
return torch.split(masked_audio_features,
357357
audio_output_lengths.flatten().tolist())
358358

359-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
359+
def get_multimodal_embeddings(
360+
self, **kwargs
361+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
360362
audio_input = self._parse_and_validate_audio_input(**kwargs)
361363
if audio_input is None:
362364
return None

vllm/model_executor/models/qwen_vl.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,9 @@ def _process_image_input(self,
740740

741741
return self.transformer.visual(image_input["data"])
742742

743-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
743+
def get_multimodal_embeddings(
744+
self, **kwargs
745+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
744746
image_input = self._parse_and_validate_image_input(**kwargs)
745747
if image_input is None:
746748
return None

vllm/model_executor/models/ultravox.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,9 @@ def _process_audio_input(
476476

477477
return result
478478

479-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
479+
def get_multimodal_embeddings(
480+
self, **kwargs
481+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
480482
audio_input = self._parse_and_validate_audio_input(**kwargs)
481483
if audio_input is None:
482484
return None

vllm/model_executor/models/whisper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,9 @@ def forward(
692692
)
693693
return decoder_outputs
694694

695-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
695+
def get_multimodal_embeddings(
696+
self, **kwargs
697+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
696698
# TODO: This method does not obey the interface for SupportsMultiModal.
697699
# Refactor this once encoder/decoder support is implemented in V1.
698700
audio_input = self._parse_and_validate_audio_input(**kwargs)

0 commit comments

Comments
 (0)