
Commit 7a56e09

Update TensorRT-LLM backend (#663)

1 parent: 1791532
29 files changed: +589 -129 lines

Diff for: .pre-commit-config.yaml

+2 -2

@@ -7,8 +7,8 @@ repos:
   rev: v1.1.13
   hooks:
   - id: remove-crlf
-- repo: https://github.com/pre-commit/mirrors-yapf
-  rev: v0.32.0
+- repo: https://github.com/google/yapf
+  rev: v0.43.0
   hooks:
   - id: yapf
 - repo: https://github.com/pre-commit/pre-commit-hooks

Diff for: all_models/inflight_batcher_llm/ensemble/config.pbtxt

+2 -2

@@ -260,12 +260,12 @@ output [
   },
   {
     name: "context_logits"
-    data_type: TYPE_FP32
+    data_type: ${logits_datatype}
     dims: [ -1, -1 ]
   },
   {
     name: "generation_logits"
-    data_type: TYPE_FP32
+    data_type: ${logits_datatype}
     dims: [ -1, -1, -1 ]
   },
   {
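
Note on the change above: the `${logits_datatype}` placeholder is substituted when the model repository is generated, before Triton loads the config. A minimal sketch of that substitution step is below; it mirrors the usual fill-template workflow but is not code from this commit, and the file path and the `TYPE_FP16` value are illustrative assumptions only.

# Illustrative sketch (assumption): filling ${logits_datatype} in a config.pbtxt.
from string import Template

def fill_config(path: str, **values) -> None:
    with open(path) as f:
        # safe_substitute leaves any placeholders we do not set untouched.
        filled = Template(f.read()).safe_substitute(**values)
    with open(path, "w") as f:
        f.write(filled)

# Hypothetical invocation; choose TYPE_FP32 / TYPE_FP16 to match the engine.
fill_config("all_models/inflight_batcher_llm/ensemble/config.pbtxt",
            logits_datatype="TYPE_FP16")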

Diff for: all_models/inflight_batcher_llm/tensorrt_llm/1/model.py

+159 -8

@@ -20,6 +20,35 @@
 METRIC_TOTAL_INPUT_TOKENS = "total_input_tokens"
 import tensorrt_llm.logger as logger
 
+# From https://github.com/pytorch/pytorch/blob/39425feac799905402abe4d15667fa47c344f2d7/torch/testing/_internal/common_utils.py#L1761
+# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
+numpy_to_torch_dtype_dict = {
+    np.bool_: torch.bool,
+    np.uint8: torch.uint8,
+    np.uint16: torch.uint16,
+    np.uint32: torch.uint32,
+    np.uint64: torch.uint64,
+    np.int8: torch.int8,
+    np.int16: torch.int16,
+    np.int32: torch.int32,
+    np.int64: torch.int64,
+    np.float16: torch.float16,
+    np.float32: torch.float32,
+    np.float64: torch.float64,
+    np.complex64: torch.complex64,
+    np.complex128: torch.complex128
+}
+
+# Dict of torch dtype -> NumPy dtype
+torch_to_numpy_dtype_dict = {
+    value: key
+    for (key, value) in numpy_to_torch_dtype_dict.items()
+}
+torch_to_numpy_dtype_dict.update({
+    torch.bfloat16: np.float32,
+    torch.complex32: np.complex64
+})
+
 
 @dataclass
 class RequestData:
@@ -224,7 +253,7 @@ def get_external_draft_tokens_config_from_request(request,
     draft_logits = get_input_tensor_by_name(request, 'draft_logits',
                                             batch_size, batch_index)
     if draft_logits is not None:
-        kwargs['logits'] = from_numpy(draft_logits).squeeze()
+        kwargs['logits'] = from_numpy(draft_logits).squeeze(dim=0)
     kwargs['acceptance_threshold'] = get_input_scalar_by_name(
         request, 'draft_acceptance_threshold', batch_size, batch_index)
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
@@ -279,6 +308,93 @@ def get_lora_config_from_request(request, batch_size=1, batch_index=0):
     return None
 
 
+def get_kv_cache_retention_config_from_request(request,
+                                               batch_size=1,
+                                               batch_index=0):
+
+    def get_tensor_and_check_length(name: str, expected_length: int):
+        tensor = get_input_tensor_by_name(request, name, batch_size,
+                                          batch_index)
+
+        if tensor is None:
+            raise RuntimeError(f"{name} must be provided.")
+
+        tensor = np.squeeze(tensor, axis=0)
+
+        if len(tensor) != expected_length:
+            raise RuntimeError(
+                f"Invalid {name} length. Expected length {expected_length}, got length {len(tensor)}"
+            )
+
+        return tensor
+
+    token_range_starts = get_input_tensor_by_name(
+        request, "retention_token_range_starts", batch_size, batch_index)
+
+    if token_range_starts is not None:
+        token_range_starts = np.squeeze(token_range_starts, axis=0)
+
+        token_range_ends = get_tensor_and_check_length(
+            "retention_token_range_ends", len(token_range_starts))
+        token_range_ends = [
+            None if end == -1 else end for end in token_range_ends
+        ]
+
+        token_range_priorities = get_tensor_and_check_length(
+            "retention_token_range_priorities", len(token_range_starts))
+
+        token_range_durations_ms = get_input_tensor_by_name(
+            request, "retention_token_range_durations_ms", batch_size,
+            batch_index)
+
+        if token_range_durations_ms is None:
+            token_range_durations_ms = [None] * len(token_range_starts)
+        else:
+            token_range_durations_ms = np.squeeze(token_range_durations_ms,
+                                                  axis=0)
+            token_range_durations_ms = [
+                None if duration == -1 else duration
+                for duration in token_range_durations_ms
+            ]
+
+            if len(token_range_durations_ms) != len(token_range_starts):
+                raise RuntimeError(
+                    f"Invalid retention_token_range_durations length. Expected length {len(token_range_starts)}, got length {len(token_range_durations_ms)}"
+                )
+
+        ranges = []
+
+        for start, end, priority, duration_ms in zip(token_range_starts,
+                                                     token_range_ends,
+                                                     token_range_priorities,
+                                                     token_range_durations_ms):
+            ranges.append(
+                trtllm.KvCacheRetentionConfig.TokenRangeRetentionConfig(
+                    token_start=start,
+                    token_end=end,
+                    priority=priority.item(),
+                    duration_ms=None if duration_ms is None else
+                    datetime.timedelta(milliseconds=duration_ms.item())))
+
+        decode_args = {}
+
+        decode_priority = get_input_scalar_by_name(
+            request, "retention_decode_priority", batch_size, batch_index)
+        if decode_priority is not None:
+            decode_args['decode_retention_priority'] = decode_priority
+
+        decode_duration_ms = get_input_scalar_by_name(
+            request, "retention_decode_duration_ms", batch_size, batch_index)
+        if decode_duration_ms is not None:
+            decode_args[
+                'decode_duration_ms'] = decode_duration_ms if decode_duration_ms != -1 else None
+
+        return trtllm.KvCacheRetentionConfig(
+            token_range_retention_configs=ranges, **decode_args)
+
+    return None
+
+
 def build_1_2_5_buckets(max_value: int) -> List[int]:
     """
     Builds a list of buckets with increasing powers of 10 multiplied by
@@ -375,6 +491,8 @@ def convert_request(request, exclude_input_from_output, decoupled):
             request, batch_size, batch_index, input_length)
         lora_config = get_lora_config_from_request(request, batch_size,
                                                    batch_index)
+        kv_cache_retention_config = get_kv_cache_retention_config_from_request(
+            request, batch_size, batch_index)
 
         # Inputs for mllama support
         encoder_input_features = get_input_tensor_by_name(
@@ -412,11 +530,15 @@ def convert_request(request, exclude_input_from_output, decoupled):
                 external_draft_tokens_config=external_draft_tokens_config,
                 prompt_tuning_config=prompt_tuning_config,
                 lora_config=lora_config,
-            ))
+                kv_cache_retention_config=kv_cache_retention_config))
     return requests
 
 
-def convert_response(response, batch_index, batch_size, num_return_sequences):
+def convert_response(response,
+                     batch_index,
+                     batch_size,
+                     num_return_sequences,
+                     expected_logits_dtype=torch.float32):
 
     if response.has_error():
         return pb_utils.InferenceResponse(output_tensors=[],
@@ -450,18 +572,24 @@ def convert_response(response, batch_index, batch_size, num_return_sequences):
                 np.expand_dims(np.array(result.log_probs, np.float32), 0)))
 
     if result.context_logits is not None:
+        assert (result.context_logits.dtype is expected_logits_dtype)
         output_tensors.append(
             pb_utils.Tensor(
                 "context_logits",
-                np.expand_dims(np.array(result.context_logits, np.float32),
-                               0)))
+                np.expand_dims(
+                    np.array(
+                        result.context_logits, torch_to_numpy_dtype_dict[
+                            result.context_logits.dtype]), 0)))
 
     if result.generation_logits is not None:
+        assert (result.generation_logits.dtype is expected_logits_dtype)
         output_tensors.append(
             pb_utils.Tensor(
                 "generation_logits",
-                np.expand_dims(np.array(result.generation_logits, np.float32),
-                               0)))
+                np.expand_dims(
+                    np.array(
+                        result.generation_logits, torch_to_numpy_dtype_dict[
+                            result.generation_logits.dtype]), 0)))
 
     if batch_size > 1:
         output_tensors.append(
@@ -555,6 +683,22 @@ def convert_timestamp_to_seconds(timestamp: str):
                                   "%m-%d-%Y %H:%M:%S.%f").timestamp())
 
 
+def triton_string_to_torch(dtype):
+    type_map = {
+        "TYPE_BOOL": torch.bool,
+        "TYPE_UINT8": torch.uint8,
+        "TYPE_INT8": torch.int8,
+        "TYPE_INT16": torch.int16,
+        "TYPE_INT32": torch.int32,
+        "TYPE_INT64": torch.int64,
+        "TYPE_FP16": torch.float16,
+        "TYPE_FP32": torch.float32,
+        "TYPE_FP64": torch.float64,
+        "TYPE_BF16": torch.bfloat16
+    }
+    return type_map[dtype]
+
+
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model
     that is created must have "TritonPythonModel" as the class name.
@@ -916,6 +1060,12 @@ def initialize(self, args):
         self.stats_check_period_ms = get_parameter(
             model_config, "stats_check_period_ms", int) or 100
 
+        self.logits_dtype = None
+        for output in model_config['output']:
+            if output['name'] == 'context_logits' or output[
+                    'name'] == 'generation_logits':
+                self.logits_dtype = triton_string_to_torch(output['data_type'])
+
         self.create_metrics(args["model_name"],
                             args["model_version"],
                             is_v1_model=executor_config.batching_type ==
@@ -1059,7 +1209,8 @@ def awaiter_loop(self):
 
                 triton_response, is_final, output_length = convert_response(
                     response, request_data.batch_index,
-                    request_data.batch_size, request_data.num_return_sequences)
+                    request_data.batch_size, request_data.num_return_sequences,
+                    self.logits_dtype)
                 with self.lock:
                     self.req_id_to_request_data[
                         req_id].num_output_tokens += output_length
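
A note on the new dtype tables introduced at the top of model.py: NumPy has no bfloat16 type, which is why the commit maps torch.bfloat16 to np.float32 for the torch-to-NumPy direction. The self-contained sketch below illustrates that behavior; it is not code from the commit, and the example tensor shape and values are made up.

# Standalone sketch of the torch<->NumPy dtype round trip used above.
import numpy as np
import torch

numpy_to_torch_dtype_dict = {np.float16: torch.float16, np.float32: torch.float32}
torch_to_numpy_dtype_dict = {v: k for (k, v) in numpy_to_torch_dtype_dict.items()}
# bfloat16 has no NumPy counterpart, so it is widened to float32.
torch_to_numpy_dtype_dict.update({torch.bfloat16: np.float32})

logits = torch.randn(2, 4, dtype=torch.bfloat16)          # illustrative tensor
np_dtype = torch_to_numpy_dtype_dict[logits.dtype]         # -> np.float32
np_logits = np.array(logits.to(torch.float32), dtype=np_dtype)
print(np_logits.dtype)  # float32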

Diff for: all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt

+45 -3

@@ -102,7 +102,7 @@ input [
   },
   {
     name: "draft_logits"
-    data_type: TYPE_FP32
+    data_type: ${logits_datatype}
     dims: [ -1, -1 ]
     optional: true
     allow_ragged_batch: true
@@ -388,6 +388,48 @@ input [
     dims: [ 1 ]
     optional: true
     allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_starts"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_ends"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_priorities"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_token_range_durations_ms"
+    data_type: TYPE_INT32
+    dims: [ -1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_decode_priority"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
+  },
+  {
+    name: "retention_decode_duration_ms"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+    optional: true
+    allow_ragged_batch: true
   }
 ]
 output [
@@ -413,12 +455,12 @@ output [
   },
   {
     name: "context_logits"
-    data_type: TYPE_FP32
+    data_type: ${logits_datatype}
     dims: [ -1, -1 ]
   },
   {
     name: "generation_logits"
-    data_type: TYPE_FP32
+    data_type: ${logits_datatype}
     dims: [ -1, -1, -1 ]
   },
   {
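
For context on how a client might exercise the new optional retention_* inputs declared above, here is a rough sketch using the Triton gRPC client. The -1 convention for open-ended range ends and durations follows the parsing in model.py; the model name, range values, and priority are illustrative assumptions, not part of this commit.

# Rough client-side sketch (not from this commit): building the optional
# KV-cache retention inputs for a request to the tensorrt_llm model.
import numpy as np
import tritonclient.grpc as grpcclient

def make_retention_inputs():
    tensors = {
        # Keep the prompt prefix starting at token 0...
        "retention_token_range_starts": np.array([[0]], dtype=np.int32),
        # ...to the end of the sequence (-1 means no explicit end).
        "retention_token_range_ends": np.array([[-1]], dtype=np.int32),
        # Illustrative priority value for that range.
        "retention_token_range_priorities": np.array([[80]], dtype=np.int32),
        # -1 means the range never expires.
        "retention_token_range_durations_ms": np.array([[-1]], dtype=np.int32),
    }
    inputs = []
    for name, value in tensors.items():
        t = grpcclient.InferInput(name, list(value.shape), "INT32")
        t.set_data_from_numpy(value)
        inputs.append(t)
    return inputs

# These would be sent alongside the usual generation inputs (input_ids,
# request_output_len, ...), e.g. via client.infer("tensorrt_llm", inputs).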

Diff for: all_models/inflight_batcher_llm/tensorrt_llm_bls/1/lib/decode.py

+6 -5

@@ -324,8 +324,8 @@ def _spec_generate(
 
             # Evaluate criteria to stop generation loop.
             # If we've hit or exceeded the max output length, should stop
-            length_stop = (len(input_ids) >=
-                           len(prompt_input_ids) + output_len)
+            length_stop = (len(input_ids)
+                           >= len(prompt_input_ids) + output_len)
             if length_stop:
                 break
             # If draft and target have same outputs, should stop. Normally target should return 1 more token.
@@ -401,9 +401,10 @@ def postprocess(self, gen_response: GenerationResponse,
         batch_index = batch_index[0][0] if batch_index is not None else 0
 
         self._accumulated_tokens[batch_index] = new_tokens if (
-            self._accumulated_tokens[batch_index] is None
-        ) else np.concatenate(
-            (self._accumulated_tokens[batch_index], new_tokens), axis=2)
+            self._accumulated_tokens[batch_index]
+            is None) else np.concatenate(
+                (self._accumulated_tokens[batch_index], new_tokens),
+                axis=2)
         sequence_lengths = np.array(
             [[self._accumulated_tokens[batch_index].shape[2]]],
             dtype=np.int32)

Diff for: all_models/inflight_batcher_llm/tensorrt_llm_bls/config.pbtxt

+2 -2

@@ -308,12 +308,12 @@ output [
   },
   {
     name: "context_logits"
-    data_type: TYPE_FP32
+    data_type: ${logits_datatype}
     dims: [ -1, -1 ]
   },
   {
     name: "generation_logits"
-    data_type: TYPE_FP32
+    data_type: ${logits_datatype}
     dims: [ -1, -1, -1 ]
   },
   {
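
Since all three config.pbtxt files now template the logits dtype, a client reading context_logits or generation_logits should expect whatever dtype `${logits_datatype}` was filled with rather than always FP32. A small hedged sketch of that mapping follows; only the two common cases are shown, and which one applies depends on how the template was filled, which is outside this commit.

# Sketch (assumption, not from this commit): the NumPy dtype a client can
# expect for the logits outputs given the value used for ${logits_datatype}.
import numpy as np

logits_datatype_to_numpy = {
    "TYPE_FP32": np.float32,
    "TYPE_FP16": np.float16,
}

def expected_logits_numpy_dtype(logits_datatype: str):
    return logits_datatype_to_numpy[logits_datatype]

assert expected_logits_numpy_dtype("TYPE_FP16") is np.float16
assert expected_logits_numpy_dtype("TYPE_FP32") is np.float32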
