
Commit 1791532: Update TensorRT-LLM backend (#658)
1 parent 87100b0

33 files changed: +1262 -156 lines

Diff for: all_models/disaggregated_serving/disaggregated_serving_bls/config.pbtxt (+22)

@@ -229,6 +229,13 @@ input [
     dims: [ 1 ]
     optional: true
   },
+  {
+    name: "return_kv_cache_reuse_stats"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    reshape: { shape: [ ] }
+    optional: true
+  },
   {
     name: "exclude_input_in_output"
     data_type: TYPE_BOOL
@@ -349,6 +356,21 @@ output [
     name: "sequence_index"
     data_type: TYPE_INT32
     dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_alloc_new_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_reused_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_alloc_total_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
   }
 ]
 instance_group [

Diff for: all_models/inflight_batcher_llm/ensemble/config.pbtxt (+37)

@@ -152,6 +152,12 @@ input [
     dims: [ 1 ]
     optional: true
   },
+  {
+    name: "return_kv_cache_reuse_stats"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    optional: true
+  },
   {
     name: "beam_width"
     data_type: TYPE_INT32
@@ -271,6 +277,21 @@ output [
     name: "sequence_index"
     data_type: TYPE_INT32
     dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_alloc_new_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_reused_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "kv_cache_alloc_total_blocks"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
   }
 ]
 ensemble_scheduling {
@@ -450,6 +471,10 @@ ensemble_scheduling {
      key: "return_generation_logits"
      value: "return_generation_logits"
    }
+    input_map {
+      key: "return_kv_cache_reuse_stats"
+      value: "return_kv_cache_reuse_stats"
+    }
    input_map {
      key: "num_return_sequences"
      value: "num_return_sequences"
@@ -525,6 +550,18 @@ ensemble_scheduling {
    output_map {
      key: "sequence_index"
      value: "sequence_index"
+    },
+    output_map {
+      key: "kv_cache_alloc_new_blocks"
+      value: "kv_cache_alloc_new_blocks"
+    },
+    output_map {
+      key: "kv_cache_reused_blocks"
+      value: "kv_cache_reused_blocks"
+    },
+    output_map {
+      key: "kv_cache_alloc_total_blocks"
+      value: "kv_cache_alloc_total_blocks"
    }
  },
  {
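Both config diffs above expose the same client-facing addition: an optional BOOL input `return_kv_cache_reuse_stats` and three INT32 outputs reporting per-request KV-cache block usage. Below is a minimal client sketch; the model name "ensemble", the server URL, and the "text_input"/"max_tokens"/"text_output" tensor names are assumptions about a typical tensorrtllm_backend deployment rather than part of this commit, while the `return_kv_cache_reuse_stats` and `kv_cache_*` names come from the configs above.

```python
# Hypothetical client sketch, not part of this commit. Only the
# return_kv_cache_reuse_stats flag and the kv_cache_* output names are taken
# from the config diffs; the other tensor names and the "ensemble" model name
# are assumptions about a typical deployment.
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype


def build_input(name, value):
    tensor = grpcclient.InferInput(name, list(value.shape), np_to_triton_dtype(value.dtype))
    tensor.set_data_from_numpy(value)
    return tensor


client = grpcclient.InferenceServerClient(url="localhost:8001")
inputs = [
    build_input("text_input", np.array([["What is the capital of France?"]], dtype=object)),
    build_input("max_tokens", np.array([[64]], dtype=np.int32)),
    # New optional flag: ask the backend to report KV-cache reuse statistics.
    build_input("return_kv_cache_reuse_stats", np.array([[True]], dtype=bool)),
]
outputs = [
    grpcclient.InferRequestedOutput("text_output"),
    grpcclient.InferRequestedOutput("kv_cache_alloc_new_blocks"),
    grpcclient.InferRequestedOutput("kv_cache_reused_blocks"),
    grpcclient.InferRequestedOutput("kv_cache_alloc_total_blocks"),
]
result = client.infer("ensemble", inputs, outputs=outputs)
print(result.as_numpy("text_output"))
print("newly allocated KV blocks:", result.as_numpy("kv_cache_alloc_new_blocks"))
print("reused KV blocks:", result.as_numpy("kv_cache_reused_blocks"))
print("total allocated KV blocks:", result.as_numpy("kv_cache_alloc_total_blocks"))
```

The disaggregated serving BLS config exposes the same tensors, so the same request shape applies there with the model name swapped.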

Diff for: all_models/inflight_batcher_llm/preprocessing/1/model.py (+166 -34)

@@ -134,8 +134,10 @@ def initialize(self, args):
             'model_type']
 
         assert self.model_type in [
-            'llava', 'blip2-opt', 'vila', 'mllama'
-        ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, vila and mllama. Got {self.model_type}."
+            'llava', 'blip2-opt', 'vila', 'mllama', 'llava_onevision'
+        ], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, vila, mllama and llava_onevision. Got {self.model_type}."
+
+        assert self.model_type != 'llava_onevision' or self.max_num_images is None or self.max_num_images <= 1, "LLaVA-OneVision does not currently support multi-image inference."
 
         llm_model_path = model_config['parameters']['gpt_model_path'][
             'string_value']
@@ -146,15 +148,17 @@ def initialize(self, args):
                 llm_model_config["pretrained_config"]["vocab_size"])
             self._setup_ptable_shape(llm_model_config)
 
-            self.vision_preprocessor = VisionPreProcessor(
-                self.model_type, AutoProcessor.from_pretrained(tokenizer_dir),
-                model_config)
+            if self.model_type == 'mllama' or self.model_type == 'llava_onevision':
+                self.vision_preprocessor = VisionPreProcessor(
+                    self.model_type,
+                    AutoProcessor.from_pretrained(tokenizer_dir), model_config)
 
         # Parse model output configs and convert Triton types to numpy types
         output_names = [
             "INPUT_ID", "DECODER_INPUT_ID", "REQUEST_INPUT_LEN",
             "REQUEST_DECODER_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS",
-            "OUT_END_ID", "OUT_PAD_ID", "OUT_PROMPT_TABLE_EXTRA_IDS"
+            "OUT_END_ID", "OUT_PAD_ID", "OUT_PROMPT_TABLE_EXTRA_IDS",
+            "PIXEL_VALUES", "IMAGE_SIZES"
         ]
         input_names = ["EMBEDDING_BIAS_WORDS", "EMBEDDING_BIAS_WEIGHTS"]
         for input_name in input_names:
@@ -270,8 +274,50 @@ def execute(self, requests):
                 assert prompt_table_extra_id.shape[
                     1] == 1, "Multiple IDs cannot be provided for a single image"
 
+            # Preprocessing vision input passed as a url or bytes tensor
+            img_urls = pb_utils.get_input_tensor_by_name(request, 'IMAGE_URL')
+            image_bytes = pb_utils.get_input_tensor_by_name(
+                request, 'IMAGE_BYTES')
+            video_bytes = pb_utils.get_input_tensor_by_name(
+                request, 'VIDEO_BYTES')
+            vision_processed_tensors = []
+            visual_tokens = []
+            if self.is_multimodal and (img_urls or image_bytes or video_bytes):
+                assert self.vision_preprocessor is not None, "Vision preprocessor for preparing images before encoding is None"
+                processed_tensors = {}
+                if self.model_type == 'mllama':
+                    processed_tensors = self.vision_preprocessor.mllama_process(
+                        queries=query.astype(str).tolist(),
+                        img_urls=img_urls,
+                        image_bytes=image_bytes,
+                    )
+                elif self.model_type == 'llava_onevision':
+                    if video_bytes is None:
+                        processed_tensors, visual_tokens = self.vision_preprocessor.llava_onevision_process_image(
+                            queries=query.astype(str).tolist(),
+                            img_urls=img_urls,
+                            image_bytes=image_bytes,
+                        )
+                    else:
+                        processed_tensors, visual_tokens = self.vision_preprocessor.llava_onevision_process_video(
+                            queries=query.astype(str).tolist(),
+                            video_bytes=video_bytes,
+                        )
+                else:
+                    raise ValueError(
+                        "Unsupported model type for IMAGE_BYTES or IMAGE_URL inputs"
+                    )
+                vision_processed_tensors = [
+                    pb_utils.Tensor.from_dlpack(k, v)
+                    for k, v in processed_tensors.items()
+                ]
+            else:
+                assert self.model_type != "mllama" and self.model_type != "llava_onevision", "Image processing requires IMAGE_BYTES or IMAGE_URL to be provided"
+
             # Preprocessing input data.
-            input_id, request_input_len = self._create_request(query)
+            # For the LLaVA_OneVision model, num_visual_features is not a fixed value
+            input_id, request_input_len = self._create_request(
+                query, visual_tokens)
             if decoder_query is not None:
                 decoder_input_id, request_decoder_input_len = self._create_request(
                     decoder_query)
@@ -294,24 +340,6 @@ def execute(self, requests):
                         input_id[i] >= self.vocab_size,
                         prompt_table_extra_id[i], 0)
 
-            # Preprocessing vision input passed as a url or bytes tensor
-            img_urls = pb_utils.get_input_tensor_by_name(request, 'IMAGE_URL')
-            image_bytes = pb_utils.get_input_tensor_by_name(
-                request, 'IMAGE_BYTES')
-            if img_urls or image_bytes:
-                assert self.vision_preprocessor != None, "Vision preprocessor for preparing images before encoding is None"
-                vision_processed_tensors = self.vision_preprocessor.process(
-                    queries=query.astype(str).tolist(),
-                    img_urls=img_urls,
-                    image_bytes=image_bytes,
-                ) if self.is_multimodal else {}
-                vision_processed_tensors = [
-                    pb_utils.Tensor.from_dlpack(k, v)
-                    for k, v in vision_processed_tensors.items()
-                ]
-            else:
-                vision_processed_tensors = []
-
             # Create output tensors. You need pb_utils.Tensor
             # objects to create pb_utils.InferenceResponse.
             input_id_tensor = pb_utils.Tensor(
@@ -489,7 +517,7 @@ def _process_multi_image_inputs(self, query, image_token_index=-200):
 
         return start_ids
 
-    def _create_request(self, query):
+    def _create_request(self, query, visual_tokens=None):
        """
            query : batch string (2D numpy array)
        """
@@ -513,7 +541,7 @@ def _create_request(self, query):
         ]
 
         if self.is_multimodal:
-            if 'blip2' in self.model_type:
+            if 'blip2' in self.model_type or 'mllama' == self.model_type:
                 pre_prompt = None
                 post_prompt = None
             elif 'llava' == self.model_type:
@@ -522,12 +550,9 @@ def _create_request(self, query):
             elif 'vila' == self.model_type:
                 pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
                 post_prompt = " ASSISTANT:"
-            elif 'mllama' == self.model_type:
-                pre_prompt = None
-                post_prompt = None
-
-            fake_prompt_id = np.arange(self.vocab_size,
-                                       self.vocab_size + self.ptable_shape[1])
+            elif 'llava_onevision' == self.model_type:
+                pre_prompt = "<|im_start|>user "
+                post_prompt = "<|im_end|><|im_start|>assistant\n"
 
             pre_prompt_id = np.array(
                 self.tokenizer.encode(
@@ -552,7 +577,26 @@ def _create_request(self, query):
                     concatenated_ids)
                 start_ids = self._setup_fake_prompts(query.shape[0],
                                                      batch_split_prompts)
+            elif self.model_type == 'llava_onevision':
+                fake_prompt_ids = []
+                extra_id = np.array(
+                    self.tokenizer.encode(
+                        '\n',
+                        add_special_tokens=self.add_special_tokens,
+                        padding=True))
+                for tokens in visual_tokens:
+                    prompt_id = np.arange(self.vocab_size,
+                                          self.vocab_size + tokens)
+                    fake_prompt_ids.append(prompt_id)
+                start_ids = [
+                    np.concatenate((pre_prompt_id, prompt_id, extra_id, ids,
+                                    post_prompt_id),
+                                   axis=0)
+                    for prompt_id, ids in zip(fake_prompt_ids, start_ids)
+                ]
             else:
+                fake_prompt_id = np.arange(
+                    self.vocab_size, self.vocab_size + self.ptable_shape[1])
                 start_ids = [
                     np.concatenate(
                         (pre_prompt_id, fake_prompt_id, ids, post_prompt_id),
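The `_create_request` change above is the core of the variable-length handling for llava_onevision: the number of visual features per sample is not fixed, so each sample gets its own contiguous range of out-of-vocabulary "fake" prompt ids sized by the visual token count reported by the vision preprocessor. A standalone sketch of that id construction, with assumed vocab size, token counts, and placeholder ids:

```python
# Standalone sketch of the llava_onevision branch above; the vocab size, visual
# token counts, and placeholder prompt/query ids are illustrative assumptions.
import numpy as np

vocab_size = 152000                    # assumed tokenizer vocab size
visual_tokens = [728, 1456]            # per-sample counts from the vision preprocessor
pre_prompt_id = np.array([101, 102])   # stands in for tokenized "<|im_start|>user "
post_prompt_id = np.array([103, 104])  # stands in for tokenized "<|im_end|><|im_start|>assistant\n"
extra_id = np.array([105])             # stands in for tokenized "\n"
start_ids = [np.array([7, 8, 9]), np.array([10, 11])]  # tokenized user queries

# One fake id per visual token, starting just past the real vocabulary, so the
# engine can substitute prompt-table rows positionally at generation time.
fake_prompt_ids = [np.arange(vocab_size, vocab_size + n) for n in visual_tokens]
start_ids = [
    np.concatenate((pre_prompt_id, prompt_id, extra_id, ids, post_prompt_id), axis=0)
    for prompt_id, ids in zip(fake_prompt_ids, start_ids)
]
print([len(ids) for ids in start_ids])  # [736, 1463]
```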
@@ -725,7 +769,7 @@ def load_images_from_urls(self, img_urls):
                     Image.open(requests.get(img_url, stream=True).raw))
         return images
 
-    def process(self, queries, img_urls=None, image_bytes=None):
+    def mllama_process(self, queries, img_urls=None, image_bytes=None):
         vision_processed_tensors = {}
         if img_urls is not None or image_bytes is not None:
             if img_urls is not None:
@@ -774,3 +818,91 @@ def process(self, queries, img_urls=None, image_bytes=None):
                     val, self.output_str_dtypes[key])
                 vision_processed_tensors[key] = val
         return vision_processed_tensors
+
+    def llava_onevision_process_image(self,
+                                      queries,
+                                      img_urls=None,
+                                      image_bytes=None):
+
+        import torch
+        vision_processed_tensors = {}
+        if img_urls is not None:
+            # download and read images
+            images = [
+                self.load_images_from_urls(urls)
+                for urls in img_urls.as_numpy()
+            ]
+        else:
+            images = [
+                img for img_list in self.load_images_tensor(image_bytes)
+                for img in img_list
+            ]
+
+        batch_size = len(images)
+        assert len(
+            queries
+        ) == batch_size, f"Image must have the same batch size as Query."
+        preprocessor_outputs = {}
+        possible_output_names = ['PIXEL_VALUES', 'IMAGE_SIZES']
+        visual_tokens = []
+        for batch_id in range(batch_size):
+            # Preprocess images and query
+            processed_vision_data = self.vision_model_processor(
+                images=images[batch_id], text='<image>', return_tensors="pt")
+            visual_tokens.append(processed_vision_data['input_ids'].shape[1])
+
+            # Create vision output tensors
+            for key in possible_output_names:
+                val = processed_vision_data.get(key.lower())
+                if val is not None:
+                    if key not in preprocessor_outputs:
+                        preprocessor_outputs[key] = []
+                    preprocessor_outputs[key].append(val)
+
+        max_patch = max(x.shape[1]
+                        for x in preprocessor_outputs['PIXEL_VALUES'])
+        preprocessor_outputs['PIXEL_VALUES'] = [
+            torch.nn.functional.pad(
+                image, (0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[1], 0, 0),
+                mode='constant')
+            for image in preprocessor_outputs['PIXEL_VALUES']
+        ]
+        for key, tensor_list in preprocessor_outputs.items():
+            val = self.convert_tensor_list_to_tensor(tensor_list)
+            if key in self.output_str_dtypes:
+                val = self.convert_tensor_to_str_dtype(
+                    val, self.output_str_dtypes[key])
+            vision_processed_tensors[key] = val
+        return vision_processed_tensors, visual_tokens
+
+    def llava_onevision_process_video(self, queries, video_bytes=None):
+        import torch
+        vision_processed_tensors = {}
+        videos = [video for video in self.load_images_tensor(video_bytes)]
+
+        batch_size = len(videos)
+        assert len(
+            queries
+        ) == batch_size, f"Video must have the same batch size as Query."
+        preprocessor_outputs = {}
+        preprocessor_outputs['PIXEL_VALUES'] = []
+        preprocessor_outputs['IS_VIDEO_INPUT'] = []
+        visual_tokens = []
+        for batch_id in range(len(queries)):
+            processed_vision_data = self.vision_model_processor(
+                videos=list(videos[batch_id]),
+                text='<video>',
+                return_tensors="pt")
+            visual_tokens.append(processed_vision_data['input_ids'].shape[1])
+            preprocessor_outputs['PIXEL_VALUES'].append(
+                processed_vision_data['pixel_values_videos'])
+            preprocessor_outputs['IS_VIDEO_INPUT'].append(
+                torch.ones((1, 1), dtype=torch.bool))
+
+        for key, tensor_list in preprocessor_outputs.items():
+            val = self.convert_tensor_list_to_tensor(tensor_list)
+            if key in self.output_str_dtypes:
+                val = self.convert_tensor_to_str_dtype(
+                    val, self.output_str_dtypes[key])
+            vision_processed_tensors[key] = val
+        return vision_processed_tensors, visual_tokens
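In `llava_onevision_process_image`, the per-image `pixel_values` tensors returned by the HF processor can differ in their patch count (dimension 1), so each one is right-padded to the batch maximum before the list is collected into a single PIXEL_VALUES tensor. A minimal sketch of that padding step, with assumed shapes:

```python
# Minimal sketch of the PIXEL_VALUES padding above; the patch counts and image
# resolution are assumed values, not taken from the commit.
import torch

pixel_values = [
    torch.randn(1, 5, 3, 384, 384),  # image processed into 5 patches
    torch.randn(1, 9, 3, 384, 384),  # image processed into 9 patches
]
max_patch = max(x.shape[1] for x in pixel_values)

# F.pad pairs apply from the last dimension backwards, so the single non-zero
# entry pads dimension 1 (the patch dimension) on the right with zeros.
padded = [
    torch.nn.functional.pad(
        x, (0, 0, 0, 0, 0, 0, 0, max_patch - x.shape[1], 0, 0), mode='constant')
    for x in pixel_values
]
batched = torch.cat(padded, dim=0)
print(batched.shape)  # torch.Size([2, 9, 3, 384, 384])
```

The video path instead collects `pixel_values_videos` directly and adds an IS_VIDEO_INPUT flag tensor so downstream models can distinguish image and video inputs.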
