@@ -1235,11 +1235,34 @@ def sample(
1235
1235
next_tokens = self .sampler (logits , sampling_metadata )
1236
1236
return next_tokens
1237
1237
1238
+ def unpack_data (self ,
1239
+ image_data : Union [List [torch .Tensor ], torch .Tensor ],
1240
+ padding_value = 0 ) -> torch .Tensor :
1241
+ if isinstance (image_data , torch .Tensor ):
1242
+ # torch.Tensor
1243
+ return image_data
1244
+ else :
1245
+ assert isinstance (
1246
+ image_data [0 ],
1247
+ torch .Tensor ), "Image data is not properly batched."
1248
+ # List[torch.Tensor]
1249
+ bsz = len (image_data )
1250
+ max_length = max (t .size (0 ) for t in image_data )
1251
+ trailing_dims = image_data [0 ].shape [1 :]
1252
+ for data in image_data :
1253
+ cur_trailing_dims = data .shape [1 :]
1254
+ assert cur_trailing_dims == trailing_dims
1255
+ output_tensor = torch .full ((bsz , max_length , * trailing_dims ),
1256
+ padding_value ,
1257
+ dtype = image_data [0 ].dtype ,
1258
+ device = image_data [0 ].device )
1259
+ for i , t in enumerate (image_data ):
1260
+ output_tensor [i , :t .size (0 )] = t
1261
+ return output_tensor
1262
+
1238
1263
def _parse_and_validate_image_input (self , ** kwargs : object ):
1239
1264
# tensor with the same shape will be batched together by
1240
1265
# MultiModalKwargs.batch, so pixel_values here can be:
1241
- # - List[List[torch.Tensor]]:
1242
- # with shape (num_tiles, 3, image_res, image_res)
1243
1266
# - List[torch.Tensor]:
1244
1267
# with shape (num_image, num_tiles, 3, image_res, image_res)
1245
1268
# - torch.Tensor:
@@ -1274,10 +1297,9 @@ def _parse_and_validate_image_input(self, **kwargs: object):
1274
1297
1275
1298
return MllamaImagePixelInputs (
1276
1299
type = "pixel_values" ,
1277
- data = pixel_values ,
1278
- aspect_ratio_ids = aspect_ratio_ids ,
1279
- aspect_ratio_mask = aspect_ratio_mask ,
1280
- )
1300
+ data = self .unpack_data (pixel_values ),
1301
+ aspect_ratio_ids = self .unpack_data (aspect_ratio_ids ),
1302
+ aspect_ratio_mask = self .unpack_data (aspect_ratio_mask ))
1281
1303
1282
1304
if image_embeds is not None :
1283
1305
raise NotImplementedError
0 commit comments