@@ -134,8 +134,10 @@ def initialize(self, args):
                 'model_type']

             assert self.model_type in [
-                'llava', 'blip2-opt', 'vila', 'mllama'
-            ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, vila and mllama. Got {self.model_type}."
+                'llava', 'blip2-opt', 'vila', 'mllama', 'llava_onevision'
+            ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, vila, mllama and llava_onevision. Got {self.model_type}."
+
+            assert self.model_type != 'llava_onevision' or self.max_num_images is None or self.max_num_images <= 1, "LLaVA-OneVision does not support multi-image inference currently."

         llm_model_path = model_config['parameters']['gpt_model_path'][
             'string_value']
@@ -146,15 +148,17 @@ def initialize(self, args):
             llm_model_config["pretrained_config"]["vocab_size"])
         self._setup_ptable_shape(llm_model_config)

-        self.vision_preprocessor = VisionPreProcessor(
-            self.model_type, AutoProcessor.from_pretrained(tokenizer_dir),
-            model_config)
+        if self.model_type == 'mllama' or self.model_type == 'llava_onevision':
+            self.vision_preprocessor = VisionPreProcessor(
+                self.model_type,
+                AutoProcessor.from_pretrained(tokenizer_dir), model_config)

         # Parse model output configs and convert Triton types to numpy types
         output_names = [
             "INPUT_ID", "DECODER_INPUT_ID", "REQUEST_INPUT_LEN",
             "REQUEST_DECODER_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS",
-            "OUT_END_ID", "OUT_PAD_ID", "OUT_PROMPT_TABLE_EXTRA_IDS"
+            "OUT_END_ID", "OUT_PAD_ID", "OUT_PROMPT_TABLE_EXTRA_IDS",
+            "PIXEL_VALUES", "IMAGE_SIZES"
         ]
         input_names = ["EMBEDDING_BIAS_WORDS", "EMBEDDING_BIAS_WEIGHTS"]
         for input_name in input_names:
@@ -270,8 +274,50 @@ def execute(self, requests):
                 assert prompt_table_extra_id.shape[
                     1] == 1, "Multiple IDs cannot be provided for a single image"

+            # Preprocessing vision input passed as a url or bytes tensor
+            img_urls = pb_utils.get_input_tensor_by_name(request, 'IMAGE_URL')
+            image_bytes = pb_utils.get_input_tensor_by_name(
+                request, 'IMAGE_BYTES')
+            video_bytes = pb_utils.get_input_tensor_by_name(
+                request, 'VIDEO_BYTES')
+            vision_processed_tensors = []
+            visual_tokens = []
+            if self.is_multimodal and (img_urls or image_bytes or video_bytes):
+                assert self.vision_preprocessor is not None, "Vision preprocessor for preparing images before encoding is None"
+                processed_tensors = {}
+                if self.model_type == 'mllama':
+                    processed_tensors = self.vision_preprocessor.mllama_process(
+                        queries=query.astype(str).tolist(),
+                        img_urls=img_urls,
+                        image_bytes=image_bytes,
+                    )
+                elif self.model_type == 'llava_onevision':
+                    if video_bytes is None:
+                        processed_tensors, visual_tokens = self.vision_preprocessor.llava_onevision_process_image(
+                            queries=query.astype(str).tolist(),
+                            img_urls=img_urls,
+                            image_bytes=image_bytes,
+                        )
+                    else:
+                        processed_tensors, visual_tokens = self.vision_preprocessor.llava_onevision_process_video(
+                            queries=query.astype(str).tolist(),
+                            video_bytes=video_bytes,
+                        )
+                else:
+                    raise ValueError(
+                        "Unsupported model type for IMAGE_BYTES or IMAGE_URL inputs"
+                    )
+                vision_processed_tensors = [
+                    pb_utils.Tensor.from_dlpack(k, v)
+                    for k, v in processed_tensors.items()
+                ]
+            else:
+                assert self.model_type != "mllama" and self.model_type != "llava_onevision", "Image processing requires IMAGE_BYTES or IMAGE_URL to be provided"
+
             # Preprocessing input data.
-            input_id, request_input_len = self._create_request(query)
+            # For the LLaVA-OneVision model, num_visual_features is not a fixed value
+            input_id, request_input_len = self._create_request(
+                query, visual_tokens)
             if decoder_query is not None:
                 decoder_input_id, request_decoder_input_len = self._create_request(
                     decoder_query)
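The branch added above only fires when a request carries one of the optional vision tensors. As a point of reference, here is a minimal client-side sketch of how such a request might be formed with `tritonclient`; it is not part of this change, the model name `preprocessing` and the `QUERY`/`REQUEST_OUTPUT_LEN` inputs are assumptions based on the stock TensorRT-LLM preprocessing config, and only `IMAGE_URL` is taken from the diff itself.

```python
# Hypothetical client sketch: sends QUERY plus an IMAGE_URL tensor so the
# preprocessing model takes the mllama / llava_onevision vision path above.
# Model name, companion inputs, and the example URL are assumptions.
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient("localhost:8000")

query = np.array([[b"Question: what is in the image? Answer:"]], dtype=object)
out_len = np.array([[64]], dtype=np.int32)
img_url = np.array([[b"http://images.cocodataset.org/val2017/000000039769.jpg"]],
                   dtype=object)

inputs = [
    httpclient.InferInput("QUERY", query.shape, "BYTES"),
    httpclient.InferInput("REQUEST_OUTPUT_LEN", out_len.shape, "INT32"),
    httpclient.InferInput("IMAGE_URL", img_url.shape, "BYTES"),
]
inputs[0].set_data_from_numpy(query)
inputs[1].set_data_from_numpy(out_len)
inputs[2].set_data_from_numpy(img_url)

result = client.infer("preprocessing", inputs)
print(result.as_numpy("INPUT_ID"))
```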
@@ -294,24 +340,6 @@ def execute(self, requests):
                         input_id[i] >= self.vocab_size,
                         prompt_table_extra_id[i], 0)

-            # Preprocessing vision input passed as a url or bytes tensor
-            img_urls = pb_utils.get_input_tensor_by_name(request, 'IMAGE_URL')
-            image_bytes = pb_utils.get_input_tensor_by_name(
-                request, 'IMAGE_BYTES')
-            if img_urls or image_bytes:
-                assert self.vision_preprocessor != None, "Vision preprocessor for preparing images before encoding is None"
-                vision_processed_tensors = self.vision_preprocessor.process(
-                    queries=query.astype(str).tolist(),
-                    img_urls=img_urls,
-                    image_bytes=image_bytes,
-                ) if self.is_multimodal else {}
-                vision_processed_tensors = [
-                    pb_utils.Tensor.from_dlpack(k, v)
-                    for k, v in vision_processed_tensors.items()
-                ]
-            else:
-                vision_processed_tensors = []
-
             # Create output tensors. You need pb_utils.Tensor
             # objects to create pb_utils.InferenceResponse.
             input_id_tensor = pb_utils.Tensor(
@@ -489,7 +517,7 @@ def _process_multi_image_inputs(self, query, image_token_index=-200):
489
517
490
518
return start_ids
491
519
492
- def _create_request (self , query ):
520
+ def _create_request (self , query , visual_tokens = None ):
493
521
"""
494
522
query : batch string (2D numpy array)
495
523
"""
@@ -513,7 +541,7 @@ def _create_request(self, query):
513
541
]
514
542
515
543
if self .is_multimodal :
516
- if 'blip2' in self .model_type :
544
+ if 'blip2' in self .model_type or 'mllama' == self . model_type :
517
545
pre_prompt = None
518
546
post_prompt = None
519
547
elif 'llava' == self .model_type :
@@ -522,12 +550,9 @@ def _create_request(self, query):
             elif 'vila' == self.model_type:
                 pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
                 post_prompt = " ASSISTANT:"
-            elif 'mllama' == self.model_type:
-                pre_prompt = None
-                post_prompt = None
-
-            fake_prompt_id = np.arange(self.vocab_size,
-                                       self.vocab_size + self.ptable_shape[1])
+            elif 'llava_onevision' == self.model_type:
+                pre_prompt = "<|im_start|>user "
+                post_prompt = "<|im_end|><|im_start|>assistant\n"

             pre_prompt_id = np.array(
                 self.tokenizer.encode(
@@ -552,7 +577,26 @@ def _create_request(self, query):
                     concatenated_ids)
                 start_ids = self._setup_fake_prompts(query.shape[0],
                                                      batch_split_prompts)
+            elif self.model_type == 'llava_onevision':
+                fake_prompt_ids = []
+                extra_id = np.array(
+                    self.tokenizer.encode(
+                        '\n',
+                        add_special_tokens=self.add_special_tokens,
+                        padding=True))
+                for tokens in visual_tokens:
+                    prompt_id = np.arange(self.vocab_size,
+                                          self.vocab_size + tokens)
+                    fake_prompt_ids.append(prompt_id)
+                start_ids = [
+                    np.concatenate((pre_prompt_id, prompt_id, extra_id, ids,
+                                    post_prompt_id),
+                                   axis=0)
+                    for prompt_id, ids in zip(fake_prompt_ids, start_ids)
+                ]
             else:
+                fake_prompt_id = np.arange(
+                    self.vocab_size, self.vocab_size + self.ptable_shape[1])
                 start_ids = [
                     np.concatenate(
                         (pre_prompt_id, fake_prompt_id, ids, post_prompt_id),
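Unlike the fixed `ptable_shape[1]` path in the `else` branch, the `llava_onevision` branch reserves a different number of placeholder ids per batch item. A standalone sketch of the same scheme, using made-up values for `vocab_size`, the token ids, and the per-request visual token counts:

```python
# Standalone illustration of the variable-length fake-prompt scheme used for
# llava_onevision above: each request reserves ids >= vocab_size, one per
# visual feature, so the engine can splice in the prompt table at those slots.
import numpy as np

vocab_size = 32000           # assumed value, for illustration only
visual_tokens = [728, 1456]  # per-request visual feature counts (made up)
text_ids = [np.array([1, 306, 4966]), np.array([1, 825, 338])]  # toy token ids
pre_id, post_id = np.array([151644]), np.array([151645])        # toy markers

start_ids = []
for tokens, ids in zip(visual_tokens, text_ids):
    fake_prompt_id = np.arange(vocab_size, vocab_size + tokens)
    start_ids.append(
        np.concatenate((pre_id, fake_prompt_id, ids, post_id), axis=0))

# Each sequence length now depends on that request's image/video resolution.
print([s.shape[0] for s in start_ids])  # -> [733, 1461]
```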
@@ -725,7 +769,7 @@ def load_images_from_urls(self, img_urls):
                 Image.open(requests.get(img_url, stream=True).raw))
         return images

-    def process(self, queries, img_urls=None, image_bytes=None):
+    def mllama_process(self, queries, img_urls=None, image_bytes=None):
         vision_processed_tensors = {}
         if img_urls is not None or image_bytes is not None:
             if img_urls is not None:
@@ -774,3 +818,91 @@ def process(self, queries, img_urls=None, image_bytes=None):
                         val, self.output_str_dtypes[key])
                 vision_processed_tensors[key] = val
         return vision_processed_tensors
+
+    def llava_onevision_process_image(self,
+                                      queries,
+                                      img_urls=None,
+                                      image_bytes=None):
+
+        import torch
+        vision_processed_tensors = {}
+        if img_urls is not None:
+            # download and read images
+            images = [
+                self.load_images_from_urls(urls)
+                for urls in img_urls.as_numpy()
+            ]
+        else:
+            images = [
+                img for img_list in self.load_images_tensor(image_bytes)
+                for img in img_list
+            ]
+
+        batch_size = len(images)
+        assert len(
+            queries
+        ) == batch_size, f"Image must have the same batch size as Query."
+        preprocessor_outputs = {}
+        possible_output_names = ['PIXEL_VALUES', 'IMAGE_SIZES']
+        visual_tokens = []
+        for batch_id in range(batch_size):
+            # Preprocess images and query
+            processed_vision_data = self.vision_model_processor(
+                images=images[batch_id], text='<image>', return_tensors="pt")
+            visual_tokens.append(processed_vision_data['input_ids'].shape[1])
+
+            # Create vision output tensors
+            for key in possible_output_names:
+                val = processed_vision_data.get(key.lower())
+                if val is not None:
+                    if key not in preprocessor_outputs:
+                        preprocessor_outputs[key] = []
+                    preprocessor_outputs[key].append(val)
+
+        max_patch = max(x.shape[1]
+                        for x in preprocessor_outputs['PIXEL_VALUES'])
+        preprocessor_outputs['PIXEL_VALUES'] = [
+            torch.nn.functional.pad(
+                image, (0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[1], 0, 0),
+                mode='constant')
+            for image in preprocessor_outputs['PIXEL_VALUES']
+        ]
+        for key, tensor_list in preprocessor_outputs.items():
+            val = self.convert_tensor_list_to_tensor(tensor_list)
+            if key in self.output_str_dtypes:
+                val = self.convert_tensor_to_str_dtype(
+                    val, self.output_str_dtypes[key])
+            vision_processed_tensors[key] = val
+        return vision_processed_tensors, visual_tokens
+
+    def llava_onevision_process_video(self, queries, video_bytes=None):
+        import torch
+        vision_processed_tensors = {}
+        videos = [video for video in self.load_images_tensor(video_bytes)]
+
+        batch_size = len(videos)
+        assert len(
+            queries
+        ) == batch_size, f"Video must have the same batch size as Query."
+        preprocessor_outputs = {}
+        preprocessor_outputs['PIXEL_VALUES'] = []
+        preprocessor_outputs['IS_VIDEO_INPUT'] = []
+        visual_tokens = []
+        for batch_id in range(len(queries)):
+            processed_vision_data = self.vision_model_processor(
+                videos=list(videos[batch_id]),
+                text='<video>',
+                return_tensors="pt")
+            visual_tokens.append(processed_vision_data['input_ids'].shape[1])
+            preprocessor_outputs['PIXEL_VALUES'].append(
+                processed_vision_data['pixel_values_videos'])
+            preprocessor_outputs['IS_VIDEO_INPUT'].append(
+                torch.ones((1, 1), dtype=torch.bool))
+
+        for key, tensor_list in preprocessor_outputs.items():
+            val = self.convert_tensor_list_to_tensor(tensor_list)
+            if key in self.output_str_dtypes:
+                val = self.convert_tensor_to_str_dtype(
+                    val, self.output_str_dtypes[key])
+            vision_processed_tensors[key] = val
+        return vision_processed_tensors, visual_tokens
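Since the HuggingFace processor can return a different number of image patches per sample, `llava_onevision_process_image` pads every `PIXEL_VALUES` entry up to the largest patch count before batching. A minimal sketch of that padding step with invented shapes, showing that the pad tuple only grows dim 1 (the patch dimension):

```python
# Sketch of the PIXEL_VALUES padding above: F.pad pairs apply from the last
# dimension backwards, so (0,0, 0,0, 0,0, 0,d, 0,0) pads only dim 1 (patches).
# Shapes are invented for illustration.
import torch
import torch.nn.functional as F

pixel_values = [
    torch.randn(1, 5, 3, 384, 384),   # sample with 5 patches
    torch.randn(1, 9, 3, 384, 384),   # sample with 9 patches
]

max_patch = max(x.shape[1] for x in pixel_values)
padded = [
    F.pad(x, (0, 0, 0, 0, 0, 0, 0, max_patch - x.shape[1], 0, 0),
          mode='constant')
    for x in pixel_values
]
print([tuple(p.shape) for p in padded])  # -> [(1, 9, 3, 384, 384), (1, 9, 3, 384, 384)]
```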