
Commit 7268388

DarkLight1337 and Isotr0py authored and committed
[Bugfix] Fix image input for Pixtral-HF (vllm-project#11741)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent eef7a71 commit 7268388

File tree

4 files changed: +52 -6 lines changed

examples/offline_inference_vision_language_multi_image.py

Lines changed: 36 additions & 5 deletions

@@ -23,7 +23,7 @@
 class ModelRequestData(NamedTuple):
     llm: LLM
     prompt: str
-    stop_token_ids: Optional[List[str]]
+    stop_token_ids: Optional[List[int]]
     image_data: List[Image]
     chat_template: Optional[str]

@@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
     prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
               "<|im_start|>assistant\n")
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
+
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
         stop_token_ids=stop_token_ids,
         image_data=[fetch_image(url) for url in image_urls],
-        chat_template=None)
+        chat_template=None,
+    )


 def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:

@@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )

-    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
+    placeholders = "<|image|>" * len(image_urls)
+    prompt = f"{placeholders}<|begin_of_text|>{question}"
     return ModelRequestData(
         llm=llm,
         prompt=prompt,

@@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
     )


+def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
+    model_name = "mistral-community/pixtral-12b"
+
+    # Adjust this as necessary to fit in GPU
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        tensor_parallel_size=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = "[IMG]" * len(image_urls)
+    prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
+    stop_token_ids = None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended

@@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     )


-def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
+def load_qwen_vl_chat(question: str,
+                      image_urls: List[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"
     llm = LLM(
         model=model_name,

@@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:

     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
     return ModelRequestData(
         llm=llm,
         prompt=prompt,

@@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
     "phi3_v": load_phi3v,
-    "qwen_vl_chat": load_qwenvl_chat,
+    "pixtral_hf": load_pixtral_hf,
+    "qwen_vl_chat": load_qwen_vl_chat,
     "qwen2_vl": load_qwen2_vl,
 }
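For context, this is roughly how the new "pixtral_hf" entry gets exercised once it is registered in the model map above. The driver below is a minimal sketch: the image URLs and sampling settings are illustrative assumptions, and only load_pixtral_hf, ModelRequestData, and the [IMG] prompt format come from this diff.

# Minimal usage sketch (assumed URLs and sampling settings, not from the commit).
from vllm import SamplingParams

image_urls = [
    "https://example.com/duck.jpg",   # placeholder URL
    "https://example.com/lion.jpg",   # placeholder URL
]

req = load_pixtral_hf("What is shown in each image?", image_urls)

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=req.stop_token_ids)

outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=sampling_params)

for o in outputs:
    print(o.outputs[0].text)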

vllm/model_executor/models/llava.py

Lines changed: 6 additions & 0 deletions

@@ -546,6 +546,12 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

+            if self.config.vision_config.model_type == "pixtral":
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=flatten_bn(pixel_values),
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
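The extra branch exists because Pixtral-HF keeps images at roughly their native resolution, so the per-image pixel tensors in one batch generally have different shapes and cannot go through the fixed-shape _validate_pixel_values path. A small illustration, with made-up tensor shapes:

# Illustration only: why variable-resolution Pixtral inputs can't be stacked
# into a single pixel_values tensor (shapes are made up).
import torch

imgs = [torch.randn(3, 512, 784), torch.randn(3, 336, 336)]

try:
    torch.stack(imgs)  # a fixed-shape validator effectively needs this to work
except RuntimeError as err:
    print("cannot stack:", err)

# The pixtral branch instead hands the images on as a flat list of tensors,
# which is roughly what flatten_bn does for nested list inputs without concat.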

vllm/model_executor/models/pixtral.py

Lines changed: 1 addition & 1 deletion

@@ -774,7 +774,7 @@ def get_num_image_tokens(
     ) -> int:
         return get_pixtral_hf_image_feature_size(
             image_size=self.vision_config.image_size,
-            patch_size=self.get_image_size(),
+            patch_size=self.vision_config.patch_size,
         )

     def get_max_image_tokens(self) -> int:
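The one-line fix matters because the helper expects the vision tower's patch size, not the input image size. A back-of-the-envelope sketch of the difference, using commonly reported Pixtral-12B defaults as assumed values (the real get_pixtral_hf_image_feature_size also accounts for image-break tokens, so this only shows the order of magnitude):

# Rough illustration (assumed defaults: image_size=1024, patch_size=16).
image_size, patch_size = 1024, 16

correct_grid = image_size // patch_size    # 64 patches per side
buggy_grid = image_size // image_size      # 1 "patch" per side

print(correct_grid ** 2, buggy_grid ** 2)  # 4096 vs 1 image tokens, roughly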

vllm/model_executor/models/utils.py

Lines changed: 9 additions & 0 deletions

@@ -281,6 +281,15 @@ def flatten_bn(
     ...


+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+    ...
+
+
 def flatten_bn(
     x: Union[List[torch.Tensor], torch.Tensor],
     *,
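The added signature is a fallback overload for callers that pass concat as a plain bool rather than a literal True or False, in which case the type checker can only promise the union of both return types. A minimal sketch of the pattern, with simplified signatures and a stand-in body rather than the real flatten_bn implementation:

# Sketch of the overload pattern (simplified; not the actual vLLM implementation).
from typing import List, Literal, Union, overload

import torch


@overload
def flatten_bn(x: Union[List[torch.Tensor], torch.Tensor],
               *, concat: Literal[True]) -> torch.Tensor: ...


@overload
def flatten_bn(x: Union[List[torch.Tensor], torch.Tensor],
               *, concat: bool = False
               ) -> Union[List[torch.Tensor], torch.Tensor]: ...


def flatten_bn(x, *, concat=False):
    # Stand-in behaviour: flatten the batch dim of a (B, N, ...) tensor, or
    # concatenate / pass through a list of per-item tensors.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    return torch.cat(x) if concat else list(x)


def parse_pixel_values(pixel_values: List[torch.Tensor], concat: bool):
    # Without the bool fallback overload, a strict type checker rejects this
    # call, because `concat` is not a literal True/False at this point.
    return flatten_bn(pixel_values, concat=concat)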
