
Commit c6bc821

yma11 authored and kylesayrs committed
[Bugfix][Model] fix mllama multi-image (vllm-project#14883)
Signed-off-by: yan ma <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9c7dd92 commit c6bc821

2 files changed: +29 −7 lines changed

tests/models/encoder_decoder/vision_language/test_mllama.py
Lines changed: 1 addition & 1 deletion

@@ -212,7 +212,7 @@ def _run_test(
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=4096,
-                     max_num_seqs=2,
+                     max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
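For context, the test above exercises the multi-image path through vllm_runner. A minimal sketch of driving the same path through vLLM's offline LLM entrypoint; the checkpoint name, image files, and prompt wording are illustrative assumptions, not part of this commit:

# Hypothetical standalone repro, not part of this commit.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed mllama checkpoint
    max_model_len=4096,
    max_num_seqs=3,                    # mirrors the test's new value
    limit_mm_per_prompt={"image": 2},  # allow up to two images per prompt
)

images = [Image.open("a.jpg"), Image.open("b.jpg")]  # assumed local files
outputs = llm.generate(
    {
        # mllama expects one <|image|> token per image in the prompt
        "prompt": "<|image|><|image|>Describe both images.",
        "multi_modal_data": {"image": images},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)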

vllm/model_executor/models/mllama.py
Lines changed: 28 additions & 6 deletions

@@ -1235,11 +1235,34 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def unpack_data(self,
+                    image_data: Union[List[torch.Tensor], torch.Tensor],
+                    padding_value=0) -> torch.Tensor:
+        if isinstance(image_data, torch.Tensor):
+            # torch.Tensor
+            return image_data
+        else:
+            assert isinstance(
+                image_data[0],
+                torch.Tensor), "Image data is not properly batched."
+            # List[torch.Tensor]
+            bsz = len(image_data)
+            max_length = max(t.size(0) for t in image_data)
+            trailing_dims = image_data[0].shape[1:]
+            for data in image_data:
+                cur_trailing_dims = data.shape[1:]
+                assert cur_trailing_dims == trailing_dims
+            output_tensor = torch.full((bsz, max_length, *trailing_dims),
+                                       padding_value,
+                                       dtype=image_data[0].dtype,
+                                       device=image_data[0].device)
+            for i, t in enumerate(image_data):
+                output_tensor[i, :t.size(0)] = t
+            return output_tensor
+
     def _parse_and_validate_image_input(self, **kwargs: object):
         # tensor with the same shape will be batched together by
         # MultiModalKwargs.batch, so pixel_values here can be:
-        # - List[List[torch.Tensor]]:
-        #        with shape (num_tiles, 3, image_res, image_res)
         # - List[torch.Tensor]:
         #        with shape (num_image, num_tiles, 3, image_res, image_res)
         # - torch.Tensor:
@@ -1274,10 +1297,9 @@ def _parse_and_validate_image_input(self, **kwargs: object):
 
         return MllamaImagePixelInputs(
             type="pixel_values",
-            data=pixel_values,
-            aspect_ratio_ids=aspect_ratio_ids,
-            aspect_ratio_mask=aspect_ratio_mask,
-        )
+            data=self.unpack_data(pixel_values),
+            aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
+            aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
 
         if image_embeds is not None:
             raise NotImplementedError
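The core of the fix is unpack_data: when prompts in a batch carry different numbers of images, MultiModalKwargs.batch yields a list of per-prompt tensors with differing first dimensions rather than a single stacked tensor, and unpack_data right-pads them into one batch tensor. A minimal sketch of the same padding scheme in isolation; the shapes below are illustrative, not taken from the commit:

import torch

# Two prompts: one with 1 image, one with 2 images; each image here has
# 4 tiles of shape (3, 8, 8) (illustrative sizes only).
pixel_values = [
    torch.ones(1, 4, 3, 8, 8),  # prompt 0: 1 image
    torch.ones(2, 4, 3, 8, 8),  # prompt 1: 2 images
]

# Same scheme as unpack_data: zero-pad along dim 0 to the max image
# count, then stack into (bsz, max_num_image, ...).
bsz = len(pixel_values)
max_length = max(t.size(0) for t in pixel_values)
batched = torch.zeros(bsz, max_length, *pixel_values[0].shape[1:])
for i, t in enumerate(pixel_values):
    batched[i, :t.size(0)] = t

print(batched.shape)              # torch.Size([2, 2, 4, 3, 8, 8])
print(batched[0, 1].abs().sum())  # tensor(0.) -> padded slot

For the list case this is functionally close to torch.nn.utils.rnn.pad_sequence with batch_first=True, but the hand-rolled loop in the commit also asserts that trailing dimensions agree across prompts before padding.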
