23
23
class ModelRequestData (NamedTuple ):
24
24
llm : LLM
25
25
prompt : str
26
- stop_token_ids : Optional [List [str ]]
26
+ stop_token_ids : Optional [List [int ]]
27
27
image_data : List [Image ]
28
28
chat_template : Optional [str ]
29
29
@@ -44,12 +44,14 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
44
44
prompt = (f"<|im_start|>user\n { placeholders } { question } <|im_end|>\n "
45
45
"<|im_start|>assistant\n " )
46
46
stop_token_ids = [93532 , 93653 , 944 , 93421 , 1019 , 93653 , 93519 ]
47
+
47
48
return ModelRequestData (
48
49
llm = llm ,
49
50
prompt = prompt ,
50
51
stop_token_ids = stop_token_ids ,
51
52
image_data = [fetch_image (url ) for url in image_urls ],
52
- chat_template = None )
53
+ chat_template = None ,
54
+ )
53
55
54
56
55
57
def load_h2onvl (question : str , image_urls : List [str ]) -> ModelRequestData :
@@ -166,7 +168,8 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
166
168
limit_mm_per_prompt = {"image" : len (image_urls )},
167
169
)
168
170
169
- prompt = f"<|image|><|image|><|begin_of_text|>{ question } "
171
+ placeholders = "<|image|>" * len (image_urls )
172
+ prompt = f"{ placeholders } <|begin_of_text|>{ question } "
170
173
return ModelRequestData (
171
174
llm = llm ,
172
175
prompt = prompt ,
@@ -209,6 +212,31 @@ def load_nvlm_d(question: str, image_urls: List[str]):
209
212
)
210
213
211
214
215
+ def load_pixtral_hf (question : str , image_urls : List [str ]) -> ModelRequestData :
216
+ model_name = "mistral-community/pixtral-12b"
217
+
218
+ # Adjust this as necessary to fit in GPU
219
+ llm = LLM (
220
+ model = model_name ,
221
+ max_model_len = 8192 ,
222
+ max_num_seqs = 2 ,
223
+ tensor_parallel_size = 2 ,
224
+ limit_mm_per_prompt = {"image" : len (image_urls )},
225
+ )
226
+
227
+ placeholders = "[IMG]" * len (image_urls )
228
+ prompt = f"<s>[INST]{ question } \n { placeholders } [/INST]"
229
+ stop_token_ids = None
230
+
231
+ return ModelRequestData (
232
+ llm = llm ,
233
+ prompt = prompt ,
234
+ stop_token_ids = stop_token_ids ,
235
+ image_data = [fetch_image (url ) for url in image_urls ],
236
+ chat_template = None ,
237
+ )
238
+
239
+
212
240
def load_phi3v (question : str , image_urls : List [str ]) -> ModelRequestData :
213
241
# num_crops is an override kwarg to the multimodal image processor;
214
242
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
@@ -244,7 +272,8 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
244
272
)
245
273
246
274
247
- def load_qwenvl_chat (question : str , image_urls : List [str ]) -> ModelRequestData :
275
+ def load_qwen_vl_chat (question : str ,
276
+ image_urls : List [str ]) -> ModelRequestData :
248
277
model_name = "Qwen/Qwen-VL-Chat"
249
278
llm = LLM (
250
279
model = model_name ,
@@ -274,6 +303,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
274
303
275
304
stop_tokens = ["<|endoftext|>" , "<|im_start|>" , "<|im_end|>" ]
276
305
stop_token_ids = [tokenizer .convert_tokens_to_ids (i ) for i in stop_tokens ]
306
+
277
307
return ModelRequestData (
278
308
llm = llm ,
279
309
prompt = prompt ,
@@ -348,7 +378,8 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
348
378
"mllama" : load_mllama ,
349
379
"NVLM_D" : load_nvlm_d ,
350
380
"phi3_v" : load_phi3v ,
351
- "qwen_vl_chat" : load_qwenvl_chat ,
381
+ "pixtral_hf" : load_pixtral_hf ,
382
+ "qwen_vl_chat" : load_qwen_vl_chat ,
352
383
"qwen2_vl" : load_qwen2_vl ,
353
384
}
354
385
0 commit comments