Skip to content

[Model] Add SupportsMultiModal.get_language_model interface #16007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/source/contributing/model/multimodal.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ Further update the model as follows:
return inputs_embeds
```

- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.

```python
class YourModelForImage2Seq(nn.Module):
...

def get_language_model(self) -> torch.nn.Module:
# Change `language_model` according to your implementation.
return self.language_model
```

- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

```diff
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ def _process_image_input(

return self.multi_modal_projector(image_outputs, image_attn_mask)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/aya_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ def _parse_and_validate_image_input(
num_patches=num_patches,
)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,9 @@ def _process_image_input(self,

return self.language_projection(query_output)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,9 @@ def _parse_and_validate_image_input(
data=self._validate_pixel_values(pixel_values),
)

def get_language_model(self) -> torch.nn.Module:
return self.model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/deepseek_vl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,9 @@ def _process_image_input(
return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/florence2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,9 @@ def _process_image_input(
pixel_values = image_input["data"]
return self._encode_image(pixel_values)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def _process_image_input(

return vision_embeddings_flat.split(patches_per_image, dim=0)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/gemma3_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,9 @@ def _process_image_input(
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,9 @@ def _process_image_input(

return self.transformer.vision(pixel_values)

def get_language_model(self) -> torch.nn.Module:
return self.transformer

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,9 @@ def _process_image_input(
e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
]

def get_language_model(self) -> torch.nn.Module:
return self.model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
12 changes: 12 additions & 0 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ def get_multimodal_embeddings(
"""
...

def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...

# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
@overload
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,9 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
else:
self.visual_token_mask = None

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,9 @@ def _process_image_input(
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,9 @@ def _process_image_input(
for i, patch_features_batch in enumerate(patch_embeddings)
]

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/llava_next_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):

return [e.flatten(0, 1) for e in embeds]

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/llava_onevision.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,9 @@ def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,9 @@ def _process_multimodal_inputs(self, modalities: dict):

return multimodal_embeddings

def get_language_model(self) -> torch.nn.Module:
return self.llm

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/mistral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ def _process_image_input(
image_embeds = (image_embeds, )
return image_embeds

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,9 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
cross_attention_states = cross_attention_states_flat
return cross_attention_states

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_cross_attention_states(
self,
image_inputs: MllamaImagePixelInputs,
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/mllama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,9 @@ def _process_image_input(
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,9 @@ def _process_image_input(
)
]

def get_language_model(self) -> torch.nn.Module:
return self.model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ def _process_image_input(

return self.multi_modal_projector(image_features)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,9 @@ def _process_image_input(

return image_embeds

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/phi4mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,3 +1802,6 @@ def get_mm_mapping(self) -> MultiModelKeys:
connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"],
)

def get_language_model(self) -> torch.nn.Module:
return self.model
3 changes: 3 additions & 0 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,9 @@ def _process_image_input(
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,9 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
**kwargs)
return modalities

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/qwen2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def _process_audio_input(self,
return torch.split(masked_audio_features,
audio_output_lengths.flatten().tolist())

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,9 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:

return modalities

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/qwen_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,9 @@ def _process_image_input(self,

return self.transformer.visual(image_input["data"])

def get_language_model(self) -> torch.nn.Module:
return self.transformer

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/skyworkr1v.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,9 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
else:
self.visual_token_mask = None

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,9 @@ def _process_audio_input(
]
return flattened_embeddings.split(embed_lens)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,9 @@ def forward(
)
return decoder_outputs

def get_language_model(self) -> torch.nn.Module:
return self.model.decoder

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
# TODO: This method does not obey the interface for SupportsMultiModal.
Expand Down