
Commit 7d394b5

DarkLight1337 committed
Add and update multi-image examples
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 9270036 commit 7d394b5

File tree

1 file changed: +36 -5 lines changed


examples/offline_inference_vision_language_multi_image.py

Lines changed: 36 additions & 5 deletions
@@ -23,7 +23,7 @@
 class ModelRequestData(NamedTuple):
     llm: LLM
     prompt: str
-    stop_token_ids: Optional[List[str]]
+    stop_token_ids: Optional[List[int]]
     image_data: List[Image]
     chat_template: Optional[str]

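The type change from List[str] to List[int] reflects how these values are consumed: vLLM's SamplingParams takes stop_token_ids as integer token IDs, not token strings. A minimal sketch of how a ModelRequestData instance would be consumed downstream, assuming a hypothetical run_generate helper (only LLM, SamplingParams, and the fields above come from this file):

from vllm import SamplingParams

def run_generate(req: ModelRequestData, max_tokens: int = 128):
    # stop_token_ids must be integer token IDs (or None); passing token
    # strings is exactly what the List[str] -> List[int] fix corrects.
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=max_tokens,
                                     stop_token_ids=req.stop_token_ids)
    outputs = req.llm.generate(
        {
            "prompt": req.prompt,
            "multi_modal_data": {"image": req.image_data},
        },
        sampling_params=sampling_params)
    for o in outputs:
        print(o.outputs[0].text)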

@@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
     prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
               "<|im_start|>assistant\n")
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
+
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
         stop_token_ids=stop_token_ids,
         image_data=[fetch_image(url) for url in image_urls],
-        chat_template=None)
+        chat_template=None,
+    )


 def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
@@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
         limit_mm_per_prompt={"image": len(image_urls)},
     )

-    prompt = f"<|image|><|image|><|begin_of_text|>{question}"
+    placeholders = "<|image|>" * len(image_urls)
+    prompt = f"{placeholders}<|begin_of_text|>{question}"
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
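With the hard-coded pair of <|image|> tokens replaced by one placeholder per URL, the prompt now agrees with limit_mm_per_prompt for any number of images. A small illustration of the resulting prompt string (the URLs and question are made up for the example):

image_urls = [
    "https://example.com/a.jpg",
    "https://example.com/b.jpg",
    "https://example.com/c.jpg",
]
placeholders = "<|image|>" * len(image_urls)
prompt = f"{placeholders}<|begin_of_text|>What do these images have in common?"
# -> "<|image|><|image|><|image|><|begin_of_text|>What do these images ..."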
@@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
     )


+def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
+    model_name = "mistral-community/pixtral-12b"
+
+    # Adjust this as necessary to fit in GPU
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        tensor_parallel_size=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = "[IMG]" * len(image_urls)
+    prompt = f"<s>[INST]{question}\n{placeholders}[/INST]"
+    stop_token_ids = None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
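The new loader mirrors the structure of the existing ones; note that tensor_parallel_size=2 assumes two GPUs are available, and the [IMG] placeholders follow Mistral's <s>[INST] ... [/INST] instruction format. A hedged usage sketch (the question and URLs are illustrative, not from the commit):

question = "What is shown in each of these images?"
image_urls = [
    "https://example.com/first.jpg",
    "https://example.com/second.jpg",
]

# Instantiates the LLM, so this needs the two GPUs noted above.
req = load_pixtral_hf(question, image_urls)

# One [IMG] token is emitted per input image, keeping the prompt in
# sync with limit_mm_per_prompt={"image": len(image_urls)}.
assert req.prompt.count("[IMG]") == len(image_urls)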
@@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
     )


-def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
+def load_qwen_vl_chat(question: str,
+                      image_urls: List[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"
     llm = LLM(
         model=model_name,
@@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:

     stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+
     return ModelRequestData(
         llm=llm,
         prompt=prompt,
@@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
     "phi3_v": load_phi3v,
-    "qwen_vl_chat": load_qwenvl_chat,
+    "pixtral_hf": load_pixtral_hf,
+    "qwen_vl_chat": load_qwen_vl_chat,
     "qwen2_vl": load_qwen2_vl,
 }

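The registry maps a model-type key to its loader, and the example script dispatches on that key. A minimal sketch of the dispatch step, assuming the key arrives as a command-line string (the get_request_data helper is hypothetical; model_example_map and ModelRequestData are module-level names from this file):

def get_request_data(model_type: str, question: str,
                     image_urls: List[str]) -> ModelRequestData:
    # Unknown keys fail fast instead of raising a KeyError deep inside
    # the loader call.
    if model_type not in model_example_map:
        raise ValueError(f"Model type {model_type} is not supported.")
    return model_example_map[model_type](question, image_urls)

req = get_request_data("pixtral_hf", "Compare these two images.", [
    "https://example.com/left.jpg",
    "https://example.com/right.jpg",
])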
