Commit 3c0ff91

[Bugfix] Fix Mllama interleaved images input support (#15564)
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
1 parent 2bc4be4

File tree

2 files changed: +52 -17 lines


examples/offline_inference/vision_language_multi_image.py

Lines changed: 2 additions & 2 deletions
@@ -229,8 +229,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
-    placeholders = "<|image|>" * len(image_urls)
-    prompt = f"{placeholders}<|begin_of_text|>{question}"
+    img_prompt = "Given the first image <|image|> and the second image<|image|>"
+    prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
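For reference, a minimal sketch of the prompt string the updated example now builds, with the placeholders interleaved into the text rather than front-loaded. The question value below is a stand-in, not from the commit:

# Sketch of the interleaved prompt the updated example builds.
question = "What is the content of each image"  # stand-in, not from the commit
img_prompt = "Given the first image <|image|> and the second image<|image|>"
prompt = f"<|begin_of_text|>{img_prompt}, {question}?"
print(prompt)
# The <|image|> placeholders now sit next to the text that refers to them,
# instead of all being front-loaded before <|begin_of_text|>.

Interleaving matters because, as the mllama.py change below documents, each span of text tokens cross-attends to the group of image placeholders that precedes it.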

vllm/model_executor/models/mllama.py

Lines changed: 50 additions & 15 deletions
@@ -180,19 +180,66 @@ def apply(
         mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)
 
+        image_token_id = self.info.get_hf_config().image_token_index
         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data
-        num_image_tokens = mm_inputs['prompt_token_ids'].count(
-            self.info.get_hf_config().image_token_index)
+        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
         image_data = mm_data.get("image", [])
         num_images = 1 if isinstance(image_data, Image) else len(image_data)
         if num_image_tokens != num_images:
             raise ValueError(
                 f"The number of image tokens ({num_image_tokens}) must be"
                 f" the same as the number of images ({num_images})")
 
+        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode)  # noqa: E501
+        # P0 & P1 do cross attention with placeholder of <IMG0>
+        # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
+        # Example input to encoder and decoder:
+        # {
+        #     'encoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128256, 128256, ..., 128256],
+        #         'prompt': '<|image|><|image|>...<|image|>',
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        #     'decoder': {
+        #         'type': 'token',
+        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+        #     },
+        # }
+
+        if mm_data:
+            # Since only the last group of consecutive images
+            # are attended by the decoded tokens, we only need to
+            # get the number of tokens for those images.
+            token_per_chunk = self.info.get_token_per_chunk_from_config()
+            num_decode_images = self._get_num_image_in_last_group(
+                mm_inputs["prompt_token_ids"])
+            num_encode_images = num_images - num_decode_images
+
+            # Set encoder prompt length based on the number of tiles.
+            # This tells the block manager to allocate correct number
+            # of slots for encoder tokens.
+            num_tiles = mm_inputs["mm_kwargs"]["num_tiles"]
+            decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
+            num_tokens = decode_tiles * token_per_chunk
+            mm_inputs["encoder_prompt_token_ids"] = [image_token_id
+                                                     ] * num_tokens
+            mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
+
         return mm_inputs
 
+    def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
+        num_images = 0
+        for token_id in prompt_token_ids[::-1]:
+            if token_id == self.info.get_hf_config().image_token_index:
+                num_images += 1
+            elif num_images > 0:
+                break
+        return num_images
+
     def _call_hf_processor(
         self,
         prompt: str,
@@ -210,19 +257,7 @@ def _call_hf_processor(
         processed_outputs["num_tiles"] = torch.tensor(num_tiles)
         for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
             processed_outputs[k] = processed_outputs[k].squeeze(0)
-        # Example input to encoder and decoder:
-        # {
-        #     'encoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-        #     },
-        #     'decoder': {
-        #         'type': 'token',
-        #         'prompt_token_ids': [128000],
-        #     },
-        # }
+
         processed_token_ids = processed_outputs.pop("input_ids")
         start_idx, end_idx = 0, processed_token_ids.size(1)
         processed_prompt_text = tokenizer.decode(processed_token_ids[0])
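To make the new _get_num_image_in_last_group helper concrete, here is a standalone sketch; the hard-coded id 128256 (the <|image|> token id that appears in the example comments above) stands in for self.info.get_hf_config().image_token_index:

from typing import List

IMAGE_TOKEN_ID = 128256  # <|image|> token id, per the example comments above

def num_images_in_last_group(prompt_token_ids: List[int]) -> int:
    # Walk the prompt backwards: skip trailing text tokens, count the
    # consecutive image tokens of the last group, and stop at the first
    # non-image token once counting has started.
    num_images = 0
    for token_id in prompt_token_ids[::-1]:
        if token_id == IMAGE_TOKEN_ID:
            num_images += 1
        elif num_images > 0:
            break
    return num_images

# <bos> <img> text <img> <img> text text
# -> the last group holds 2 images; the trailing text tokens are skipped.
print(num_images_in_last_group(
    [128000, 128256, 3923, 128256, 128256, 374, 279]))  # prints 2

Note that trailing text tokens are skipped before counting begins, so the last image group is found even when text follows it; per the comments in apply, decode-time tokens cross-attend only to that group.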

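The encoder-prompt sizing in apply then reduces to arithmetic over the tiles of that last group. A toy walk-through with assumed numbers; the real token_per_chunk comes from get_token_per_chunk_from_config(), so the value below is illustrative only:

import torch

# Assumed numbers, for illustration only: 3 images, the last consecutive
# group holds 2 of them, and token_per_chunk is a made-up stand-in for
# get_token_per_chunk_from_config().
token_per_chunk = 1025
num_tiles = torch.tensor([4, 4, 2])  # tiles per image, as in mm_kwargs["num_tiles"]
num_images = 3
num_decode_images = 2                # result of _get_num_image_in_last_group
num_encode_images = num_images - num_decode_images  # 1

# Only the tiles of the last image group need encoder slots.
decode_tiles = num_tiles[num_encode_images:num_images].sum().item()  # 4 + 2 = 6
num_tokens = decode_tiles * token_per_chunk  # 6 * 1025 = 6150 encoder tokens
print(num_tokens)

Padding encoder_prompt_token_ids to this length is what tells the block manager to reserve the correct number of slots for encoder tokens.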