
Commit 72c3433

ywang96 authored and rasmith committed
[Bugfix][VLM] Fix mixed-modality inference backward compatibility for V0 (vllm-project#12313)
Signed-off-by: Roger Wang <[email protected]>
1 parent 102e975 commit 72c3433

2 files changed: +92 −28 lines


vllm/model_executor/models/llava_onevision.py

Lines changed: 44 additions & 9 deletions
@@ -816,7 +816,7 @@ def apply_pooling(self, image_features, stride=2):
         return image_feature
 
     def get_multimodal_embeddings(
-            self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
+            self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
             return None
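This signature change is the crux of the interface update: `get_multimodal_embeddings` used to hand the model runner a list of `(embeddings, modality)` pairs, and now returns a flat tuple of per-item embedding tensors with no modality tag attached. A minimal sketch of the two contracts (type names simplified here; the old annotation used vLLM's `NestedTensors` rather than plain `torch.Tensor`):

from typing import List, Optional, Tuple

import torch

# Old contract: each entry carried its modality tag alongside the embeddings.
OldMMReturn = Optional[List[Tuple[torch.Tensor, str]]]  # [(emb, "image"), ...]

# New contract: a flat tuple of per-item embedding tensors, ordered to match
# the multimodal placeholder tokens in the prompt.
NewMMReturn = Optional[Tuple[torch.Tensor, ...]]  # (img_emb, vid_emb, ...)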
@@ -842,8 +842,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[List[Tuple[NestedTensors,
-                                                   str]]] = None,
+        multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
@@ -852,6 +851,34 @@ def get_input_embeddings(
                 [self.config.image_token_index, self.config.video_token_index])
         return inputs_embeds
 
+    def get_input_embeddings_v0(
+        self,
+        input_ids: torch.Tensor,
+        image_input: Optional[NestedTensors] = None,
+        video_input: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+
+        inputs_embeds = self.get_input_embeddings(input_ids)
+        if image_input is not None:
+            image_embeds = self._process_image_input(image_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                image_embeds,
+                placeholder_token_id=self.config.image_token_index,
+            )
+
+        if video_input is not None:
+            video_embeds = self._process_video_pixels(video_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                video_embeds,
+                placeholder_token_id=self.config.video_token_index,
+            )
+
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
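`merge_multimodal_embeddings`, used by both models in this commit, splices one modality's embeddings into the text embedding sequence at that modality's placeholder tokens. A simplified, runnable stand-in (hypothetical `merge_by_placeholder`, assuming flat tensors instead of vLLM's `NestedTensors`) shows the core scatter pattern:

import torch

def merge_by_placeholder(input_ids: torch.Tensor,
                         inputs_embeds: torch.Tensor,
                         mm_embeds: torch.Tensor,
                         placeholder_token_id: int) -> torch.Tensor:
    # input_ids: (seq_len,); inputs_embeds: (seq_len, hidden);
    # mm_embeds: (num_placeholders, hidden), ordered by appearance.
    mask = input_ids == placeholder_token_id
    assert int(mask.sum()) == mm_embeds.shape[0], "placeholder count mismatch"
    out = inputs_embeds.clone()
    out[mask] = mm_embeds.to(out.dtype)  # overwrite only this modality's slots
    return out

Because each pass touches only its own placeholder id, running the image merge and then the video merge over the same `inputs_embeds`, as `get_input_embeddings_v0` does, composes cleanly for mixed prompts.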
@@ -871,13 +898,21 @@ def forward(
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
+        # NOTE: In v1, inputs_embeds is always generated at model runner from
+        # `get_multimodal_embeddings` and `get_input_embeddings`, this
+        # condition is only for v0 compatibility.
         elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      multimodal_embeddings)
-            input_ids = None
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            video_input = self._parse_and_validate_video_input(**kwargs)
+
+            if image_input is None and video_input is None:
+                inputs_embeds = None
+            else:
+                inputs_embeds = self.get_input_embeddings_v0(
+                    input_ids,
+                    image_input=image_input,
+                    video_input=video_input)
+                input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
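This branch is the actual V0 fix: instead of funneling everything through the V1-style `get_multimodal_embeddings`, it parses each modality independently and preserves the text-only fast path. A toy rendering of the three cases (function and return strings are illustrative, not from the diff):

from typing import Optional

def v0_branch(image_input: Optional[object],
              video_input: Optional[object]) -> str:
    # Text-only request: leave inputs_embeds unset so the language model
    # embeds input_ids itself.
    if image_input is None and video_input is None:
        return "inputs_embeds=None, input_ids kept"
    # At least one modality present: build merged embeddings, then drop
    # input_ids so the LM consumes inputs_embeds instead.
    return "inputs_embeds=get_input_embeddings_v0(...), input_ids=None"

print(v0_branch(None, None))          # text-only
print(v0_branch(object(), None))      # image-only
print(v0_branch(object(), object()))  # mixed image + video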

vllm/model_executor/models/qwen2_vl.py

Lines changed: 48 additions & 19 deletions
@@ -55,7 +55,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors, VideoItem)
+                                    VideoItem)
 from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
                                    MultiModalDataItems, MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -1233,7 +1233,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         return modalities
 
     def get_multimodal_embeddings(
-            self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
+            self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
 
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
@@ -1260,8 +1260,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[List[Tuple[NestedTensors,
-                                                   str]]] = None,
+        multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
@@ -1270,6 +1269,33 @@ def get_input_embeddings(
                 [self.config.image_token_id, self.config.video_token_id])
         return inputs_embeds
 
+    def get_input_embeddings_v0(
+        self,
+        input_ids: torch.Tensor,
+        image_input: Optional[tuple[torch.Tensor, ...]] = None,
+        video_input: Optional[tuple[torch.Tensor, ...]] = None,
+    ) -> torch.Tensor:
+
+        inputs_embeds = self.get_input_embeddings(input_ids)
+        if image_input is not None:
+            image_embeds = self._process_image_input(image_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                image_embeds,
+                placeholder_token_id=self.config.image_token_id,
+            )
+
+        if video_input is not None:
+            video_embeds = self._process_video_input(video_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                video_embeds,
+                placeholder_token_id=self.config.video_token_id,
+            )
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1303,22 +1329,25 @@ def forward(
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
+        # NOTE: In v1, inputs_embeds is always generated at model runner from
+        # `get_multimodal_embeddings` and `get_input_embeddings`, this
+        # condition is only for v0 compatibility.
         elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-
-            # We need to check for usage of mrope here in case there is
-            # multimodal data.
-            # TODO (ywang96): move this to model runner in V1.
-            if multimodal_embeddings is not None and uses_mrope(self.config):
-                assert positions.ndim == 2 and positions.size(0) == 3, (
-                    "multimodal section rotary embedding requires "
-                    f"(3, seq_len) positions, but got {positions.size()}")
-
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      multimodal_embeddings)
-            input_ids = None
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            video_input = self._parse_and_validate_video_input(**kwargs)
+
+            if image_input is None and video_input is None:
+                inputs_embeds = None
+            else:
+                if uses_mrope(self.config):
+                    assert positions.ndim == 2 and positions.size(0) == 3, (
+                        "multimodal section rotary embedding requires "
+                        f"(3, seq_len) positions, but got {positions.size()}")
+                inputs_embeds = self.get_input_embeddings_v0(
+                    input_ids,
+                    image_input=image_input,
+                    video_input=video_input)
+                input_ids = None
 
         hidden_states = self.language_model.model(
             input_ids=input_ids,
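One behavioral nuance in this hunk: the M-RoPE check previously ran whenever `multimodal_embeddings` was non-None, and now runs inside the multimodal branch. Qwen2-VL's multimodal section rotary embedding takes three stacked position streams (temporal, height, width), hence the `(3, seq_len)` shape the assert enforces. A runnable illustration of the shape contract, with made-up positions:

import torch

seq_len = 8
# M-RoPE expects three position rows stacked into one tensor; a flat
# (seq_len,) positions tensor would trip the assertion in forward().
positions = torch.arange(seq_len).repeat(3, 1)  # shape (3, 8)

assert positions.ndim == 2 and positions.size(0) == 3, (
    "multimodal section rotary embedding requires "
    f"(3, seq_len) positions, but got {positions.size()}")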
