
Commit 2a91bc1

NickLucche authored and zRzRzRzRzRzRzR committed
[Model] Add SupportsMultiModal.get_language_model interface (vllm-project#16007)
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: zRzRzRzRzRzRzR <[email protected]>
1 parent c8e2b50 commit 2a91bc1

33 files changed (+116 −0 lines)

docs/source/contributing/model/multimodal.md (+11 lines)

@@ -79,6 +79,17 @@ Further update the model as follows:
          return inputs_embeds
   ```
 
+- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.
+
+  ```python
+  class YourModelForImage2Seq(nn.Module):
+      ...
+
+      def get_language_model(self) -> torch.nn.Module:
+          # Change `language_model` according to your implementation.
+          return self.language_model
+  ```
+
 - Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
 
   ```diff
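
A brief caller-side sketch (editor's illustration, not part of this commit; the helper name and the hasattr fallback are hypothetical) of what the new getter buys: code outside a model can reach the text backbone without knowing each architecture's attribute layout.

```python
import torch

def get_text_backbone(model: torch.nn.Module) -> torch.nn.Module:
    """Return the language-model component if the model exposes the new getter."""
    # Models implementing SupportsMultiModal.get_language_model (added in this
    # commit) expose the text backbone behind a uniform accessor; fall back to
    # the model itself for plain text-only architectures.
    if hasattr(model, "get_language_model"):
        return model.get_language_model()
    return model

def count_lm_params(model: torch.nn.Module) -> int:
    """Example use: count only the language-model parameters of a checkpoint."""
    return sum(p.numel() for p in get_text_backbone(model).parameters())
```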

vllm/model_executor/models/aria.py (+3 lines)

@@ -605,6 +605,9 @@ def _process_image_input(
 
         return self.multi_modal_projector(image_outputs, image_attn_mask)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/aya_vision.py (+3 lines)

@@ -424,6 +424,9 @@ def _parse_and_validate_image_input(
             num_patches=num_patches,
         )
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/blip2.py (+3 lines)

@@ -627,6 +627,9 @@ def _process_image_input(self,
 
         return self.language_projection(query_output)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/chameleon.py (+3 lines)

@@ -988,6 +988,9 @@ def _parse_and_validate_image_input(
             data=self._validate_pixel_values(pixel_values),
         )
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/deepseek_vl2.py (+3 lines)

@@ -604,6 +604,9 @@ def _process_image_input(
         return self._pixel_values_to_embedding(
             pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/florence2.py (+3 lines)

@@ -1050,6 +1050,9 @@ def _process_image_input(
         pixel_values = image_input["data"]
         return self._encode_image(pixel_values)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/fuyu.py (+3 lines)

@@ -341,6 +341,9 @@ def _process_image_input(
 
         return vision_embeddings_flat.split(patches_per_image, dim=0)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/gemma3_mm.py (+3 lines)

@@ -591,6 +591,9 @@ def _process_image_input(
             e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/glm4v.py (+3 lines)

@@ -596,6 +596,9 @@ def _process_image_input(
 
         return self.transformer.vision(pixel_values)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.transformer
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/idefics3.py (+3 lines)

@@ -710,6 +710,9 @@ def _process_image_input(
             e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/interfaces.py (+12 lines)

@@ -56,6 +56,18 @@ def get_multimodal_embeddings(
         """
         ...
 
+    def get_language_model(self) -> torch.nn.Module:
+        """
+        Returns the underlying language model used for text generation.
+
+        This is typically the `torch.nn.Module` instance responsible for
+        processing the merged multimodal embeddings and producing hidden states
+
+        Returns:
+            torch.nn.Module: The core language model component.
+        """
+        ...
+
     # Only for models that support v0 chunked prefill
     # TODO(ywang96): Remove this overload once v0 is deprecated
     @overload
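
A minimal implementation sketch (toy module invented for illustration, not taken from this commit): the per-model hunks below show that the attribute actually returned varies (`language_model`, `model`, `llm`, `transformer`, even `model.decoder`), while the interface only requires handing back the module that turns the merged embeddings into hidden states.

```python
import torch
import torch.nn as nn

class ToyImage2SeqModel(nn.Module):
    """Toy stand-in for a multimodal model; layers are placeholders."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Linear(768, 1024)  # stand-in vision encoder
        self.llm = nn.Linear(1024, 1024)          # stand-in language model

    def get_language_model(self) -> torch.nn.Module:
        # Return whichever attribute holds the text backbone; callers never
        # need to know that this particular model calls it `llm`.
        return self.llm
```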

vllm/model_executor/models/internvl.py (+3 lines)

@@ -884,6 +884,9 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         else:
             self.visual_token_mask = None
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/llava.py (+3 lines)

@@ -674,6 +674,9 @@ def _process_image_input(
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/llava_next.py (+3 lines)

@@ -480,6 +480,9 @@ def _process_image_input(
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/llava_next_video.py (+3 lines)

@@ -421,6 +421,9 @@ def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
 
         return [e.flatten(0, 1) for e in embeds]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         video_input = self._parse_and_validate_video_input(**kwargs)

vllm/model_executor/models/llava_onevision.py (+3 lines)

@@ -852,6 +852,9 @@ def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
         image_feature = image_feature.view(batch_frames, -1, dim)
         return image_feature
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)

vllm/model_executor/models/minicpmv.py (+3 lines)

@@ -892,6 +892,9 @@ def _process_multimodal_inputs(self, modalities: dict):
 
         return multimodal_embeddings
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.llm
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)

vllm/model_executor/models/mistral3.py (+3 lines)

@@ -514,6 +514,9 @@ def _process_image_input(
         image_embeds = (image_embeds, )
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/mllama.py (+3 lines)

@@ -1325,6 +1325,9 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
         cross_attention_states = cross_attention_states_flat
         return cross_attention_states
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_cross_attention_states(
         self,
         image_inputs: MllamaImagePixelInputs,

vllm/model_executor/models/mllama4.py (+3 lines)

@@ -742,6 +742,9 @@ def _process_image_input(
             for img in vision_embeddings_flat.split(patches_per_image, dim=0)
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(self,
                                   **kwargs) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/molmo.py (+3 lines)

@@ -1488,6 +1488,9 @@ def _process_image_input(
             )
         ]
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/paligemma.py (+3 lines)

@@ -323,6 +323,9 @@ def _process_image_input(
 
         return self.multi_modal_projector(image_features)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/phi3v.py (+3 lines)

@@ -674,6 +674,9 @@ def _process_image_input(
 
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/phi4mm.py (+3 lines)

@@ -1802,3 +1802,6 @@ def get_mm_mapping(self) -> MultiModelKeys:
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
         )
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model

vllm/model_executor/models/pixtral.py (+3 lines)

@@ -396,6 +396,9 @@ def _process_image_input(
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/qwen2_5_vl.py (+3 lines)

@@ -967,6 +967,9 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
             **kwargs)
         return modalities
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

vllm/model_executor/models/qwen2_audio.py (+3 lines)

@@ -355,6 +355,9 @@ def _process_audio_input(self,
         return torch.split(masked_audio_features,
                            audio_output_lengths.flatten().tolist())
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)

vllm/model_executor/models/qwen2_vl.py (+3 lines)

@@ -1276,6 +1276,9 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
 
         return modalities
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

vllm/model_executor/models/qwen_vl.py (+3 lines)

@@ -740,6 +740,9 @@ def _process_image_input(self,
 
         return self.transformer.visual(image_input["data"])
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.transformer
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/skyworkr1v.py (+3 lines)

@@ -889,6 +889,9 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         else:
             self.visual_token_mask = None
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)

vllm/model_executor/models/ultravox.py (+3 lines)

@@ -563,6 +563,9 @@ def _process_audio_input(
         ]
         return flattened_embeddings.split(embed_lens)
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)

vllm/model_executor/models/whisper.py (+3 lines)

@@ -692,6 +692,9 @@ def forward(
         )
         return decoder_outputs
 
+    def get_language_model(self) -> torch.nn.Module:
+        return self.model.decoder
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         # TODO: This method does not obey the interface for SupportsMultiModal.
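
As the Whisper hunk above illustrates, the returned module need not be a top-level attribute: for encoder-decoder models the text-generating half is the decoder. A hypothetical sketch of that pattern (names invented for illustration, not from this commit):

```python
import torch.nn as nn

class ToyEncoderDecoderASR(nn.Module):
    """Toy encoder-decoder model; layers are placeholders."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.ModuleDict({
            "encoder": nn.Linear(80, 256),   # stand-in audio encoder
            "decoder": nn.Linear(256, 256),  # stand-in text decoder
        })

    def get_language_model(self) -> nn.Module:
        # Mirrors Whisper's `return self.model.decoder`: the getter reaches one
        # level down to the submodule that generates text.
        return self.model["decoder"]
```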
