 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors, VideoItem)
+                                    VideoItem)
 from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
                                    MultiModalDataItems, MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -1233,7 +1233,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         return modalities
 
     def get_multimodal_embeddings(
-            self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
+            self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
 
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
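Note on the signature change above: the return value goes from a list of (embeddings, modality) pairs to a flat tuple of per-item embedding tensors. A rough sketch of what that looks like from the caller's side (the shapes and the `embeds` variable below are illustrative assumptions, not taken from this diff):

```python
import torch

# Before: Optional[List[Tuple[NestedTensors, str]]], e.g.
#   [(image_embeds, "image"), (video_embeds, "video")]
# After: Optional[tuple[torch.Tensor, ...]], e.g. one tensor per multimodal
#   item, so the caller no longer has to unpack a modality tag.
embeds: tuple[torch.Tensor, ...] = (
    torch.zeros(16, 4096),  # hypothetical image item: 16 placeholder tokens
    torch.zeros(64, 4096),  # hypothetical video item: 64 placeholder tokens
)
flat = torch.cat(embeds, dim=0)  # what a runner might do before merging
```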
@@ -1260,8 +1260,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[List[Tuple[NestedTensors,
-                                                    str]]] = None,
+        multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
@@ -1270,6 +1269,33 @@ def get_input_embeddings(
             [self.config.image_token_id, self.config.video_token_id])
         return inputs_embeds
 
+    def get_input_embeddings_v0(
+        self,
+        input_ids: torch.Tensor,
+        image_input: Optional[tuple[torch.Tensor, ...]] = None,
+        video_input: Optional[tuple[torch.Tensor, ...]] = None,
+    ) -> torch.Tensor:
+
+        inputs_embeds = self.get_input_embeddings(input_ids)
+        if image_input is not None:
+            image_embeds = self._process_image_input(image_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                image_embeds,
+                placeholder_token_id=self.config.image_token_id,
+            )
+
+        if video_input is not None:
+            video_embeds = self._process_video_input(video_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                video_embeds,
+                placeholder_token_id=self.config.video_token_id,
+            )
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
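For readers unfamiliar with `merge_multimodal_embeddings`, here is a minimal standalone sketch of the scatter it performs. This is a simplified re-implementation for illustration (the `_sketch` suffix and the shapes are assumptions), not the vLLM helper itself:

```python
import torch

def merge_multimodal_embeddings_sketch(
    input_ids: torch.Tensor,          # (seq_len,)
    inputs_embeds: torch.Tensor,      # (seq_len, hidden_size)
    multimodal_embeds: torch.Tensor,  # (num_placeholders, hidden_size)
    placeholder_token_id: int,
) -> torch.Tensor:
    # Rows of inputs_embeds whose token id equals the placeholder id are
    # overwritten, in order, with the rows of multimodal_embeds.
    mask = input_ids == placeholder_token_id
    assert int(mask.sum()) == multimodal_embeds.shape[0], (
        "placeholder token count must match the number of embedding rows")
    out = inputs_embeds.clone()
    out[mask] = multimodal_embeds.to(dtype=out.dtype)
    return out
```

`get_input_embeddings_v0` applies this merge twice, once per modality, using `image_token_id` and `video_token_id` as the placeholder ids.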
@@ -1303,22 +1329,25 @@ def forward(
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
+        # NOTE: In v1, inputs_embeds is always generated at model runner from
+        # `get_multimodal_embeddings` and `get_input_embeddings`, this
+        # condition is only for v0 compatibility.
         elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-
-            # We need to check for usage of mrope here in case there is
-            # multimodal data.
-            # TODO (ywang96): move this to model runner in V1.
-            if multimodal_embeddings is not None and uses_mrope(self.config):
-                assert positions.ndim == 2 and positions.size(0) == 3, (
-                    "multimodal section rotary embedding requires "
-                    f"(3, seq_len) positions, but got {positions.size()}")
-
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      multimodal_embeddings)
-            input_ids = None
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            video_input = self._parse_and_validate_video_input(**kwargs)
+
+            if image_input is None and video_input is None:
+                inputs_embeds = None
+            else:
+                if uses_mrope(self.config):
+                    assert positions.ndim == 2 and positions.size(0) == 3, (
+                        "multimodal section rotary embedding requires "
+                        f"(3, seq_len) positions, but got {positions.size()}")
+                inputs_embeds = self.get_input_embeddings_v0(
+                    input_ids,
+                    image_input=image_input,
+                    video_input=video_input)
+                input_ids = None
 
         hidden_states = self.language_model.model(
             input_ids=input_ids,
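Taken together, the v1 and v0 execution paths now differ only in where the multimodal embeddings are built. A rough caller-side sketch (the helper names and the bare `model` parameter are illustrative, not part of vLLM):

```python
import torch

def run_v1_style(model, input_ids: torch.Tensor, positions: torch.Tensor,
                 **mm_kwargs):
    """v1 sketch: the model runner resolves multimodal inputs before forward."""
    mm_embeds = model.get_multimodal_embeddings(**mm_kwargs)  # tuple or None
    inputs_embeds = model.get_input_embeddings(input_ids, mm_embeds)
    # forward never sees raw pixel kwargs; token ids are dropped once embedded.
    return model(input_ids=None, positions=positions,
                 inputs_embeds=inputs_embeds)

def run_v0_style(model, input_ids: torch.Tensor, positions: torch.Tensor,
                 **mm_kwargs):
    """v0 sketch: raw multimodal kwargs reach forward, which parses them and
    falls back to get_input_embeddings_v0 internally."""
    return model(input_ids=input_ids, positions=positions,
                 inputs_embeds=None, **mm_kwargs)
```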