|
20 | 20 | METRIC_TOTAL_INPUT_TOKENS = "total_input_tokens"
|
21 | 21 | import tensorrt_llm.logger as logger
|
22 | 22 |
|
| 23 | +# From https://github.com/pytorch/pytorch/blob/39425feac799905402abe4d15667fa47c344f2d7/torch/testing/_internal/common_utils.py#L1761 |
| 24 | +# Dict of NumPy dtype -> torch dtype (when the correspondence exists) |
| 25 | +numpy_to_torch_dtype_dict = { |
| 26 | + np.bool_: torch.bool, |
| 27 | + np.uint8: torch.uint8, |
| 28 | + np.uint16: torch.uint16, |
| 29 | + np.uint32: torch.uint32, |
| 30 | + np.uint64: torch.uint64, |
| 31 | + np.int8: torch.int8, |
| 32 | + np.int16: torch.int16, |
| 33 | + np.int32: torch.int32, |
| 34 | + np.int64: torch.int64, |
| 35 | + np.float16: torch.float16, |
| 36 | + np.float32: torch.float32, |
| 37 | + np.float64: torch.float64, |
| 38 | + np.complex64: torch.complex64, |
| 39 | + np.complex128: torch.complex128 |
| 40 | +} |
| 41 | + |
| 42 | +# Dict of torch dtype -> NumPy dtype |
| 43 | +torch_to_numpy_dtype_dict = { |
| 44 | + value: key |
| 45 | + for (key, value) in numpy_to_torch_dtype_dict.items() |
| 46 | +} |
| 47 | +torch_to_numpy_dtype_dict.update({ |
| 48 | + torch.bfloat16: np.float32, |
| 49 | + torch.complex32: np.complex64 |
| 50 | +}) |
| 51 | + |
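
As a quick sanity check, here is a minimal sketch of how the two lookup tables above round-trip (plain Python, outside the diff; torch >= 2.3 is assumed for the unsigned 16/32/64-bit dtypes):

    import numpy as np
    import torch

    # A NumPy dtype maps to its torch twin and back again.
    t = torch.zeros(2, dtype=numpy_to_torch_dtype_dict[np.float16])
    assert torch_to_numpy_dtype_dict[t.dtype] is np.float16

    # bfloat16 has no NumPy equivalent, so the reverse table widens it to float32.
    assert torch_to_numpy_dtype_dict[torch.bfloat16] is np.float32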
23 | 52 |
|
24 | 53 | @dataclass
|
25 | 54 | class RequestData:
|
@@ -224,7 +253,7 @@ def get_external_draft_tokens_config_from_request(request,
|
224 | 253 | draft_logits = get_input_tensor_by_name(request, 'draft_logits',
|
225 | 254 | batch_size, batch_index)
|
226 | 255 | if draft_logits is not None:
|
227 | | - kwargs['logits'] = from_numpy(draft_logits).squeeze() |
| 256 | + kwargs['logits'] = from_numpy(draft_logits).squeeze(dim=0) |
228 | 257 | kwargs['acceptance_threshold'] = get_input_scalar_by_name(
|
229 | 258 | request, 'draft_acceptance_threshold', batch_size, batch_index)
|
230 | 259 | kwargs = {k: v for k, v in kwargs.items() if v is not None}
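
The switch from squeeze() to squeeze(dim=0) matters when a non-batch axis also happens to have size 1: a bare squeeze() drops every size-1 dimension, not just the batch one. A small illustration (shapes are illustrative only):

    import torch

    draft_logits = torch.zeros(1, 1, 32000)                  # (batch, draft_len, vocab), draft_len == 1
    assert draft_logits.squeeze().shape == (32000,)          # old code: draft axis lost too
    assert draft_logits.squeeze(dim=0).shape == (1, 32000)   # new code: only the batch axis removed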
|
@@ -279,6 +308,93 @@ def get_lora_config_from_request(request, batch_size=1, batch_index=0):
|
279 | 308 | return None
|
280 | 309 |
|
281 | 310 |
|
| 311 | +def get_kv_cache_retention_config_from_request(request, |
| 312 | + batch_size=1, |
| 313 | + batch_index=0): |
| 314 | + |
| 315 | + def get_tensor_and_check_length(name: str, expected_length: int): |
| 316 | + tensor = get_input_tensor_by_name(request, name, batch_size, |
| 317 | + batch_index) |
| 318 | + |
| 319 | + if tensor is None: |
| 320 | + raise RuntimeError(f"{name} must be provided.") |
| 321 | + |
| 322 | + tensor = np.squeeze(tensor, axis=0) |
| 323 | + |
| 324 | + if len(tensor) != expected_length: |
| 325 | + raise RuntimeError( |
| 326 | + f"Invalid {name} length. Expected length {expected_length}, got length {len(tensor)}" |
| 327 | + ) |
| 328 | + |
| 329 | + return tensor |
| 330 | + |
| 331 | + token_range_starts = get_input_tensor_by_name( |
| 332 | + request, "retention_token_range_starts", batch_size, batch_index) |
| 333 | + |
| 334 | + if token_range_starts is not None: |
| 335 | + token_range_starts = np.squeeze(token_range_starts, axis=0) |
| 336 | + |
| 337 | + token_range_ends = get_tensor_and_check_length( |
| 338 | + "retention_token_range_ends", len(token_range_starts)) |
| 339 | + token_range_ends = [ |
| 340 | + None if end == -1 else end for end in token_range_ends |
| 341 | + ] |
| 342 | + |
| 343 | + token_range_priorities = get_tensor_and_check_length( |
| 344 | + "retention_token_range_priorities", len(token_range_starts)) |
| 345 | + |
| 346 | + token_range_durations_ms = get_input_tensor_by_name( |
| 347 | + request, "retention_token_range_durations_ms", batch_size, |
| 348 | + batch_index) |
| 349 | + |
| 350 | + if token_range_durations_ms is None: |
| 351 | + token_range_durations_ms = [None] * len(token_range_starts) |
| 352 | + else: |
| 353 | + token_range_durations_ms = np.squeeze(token_range_durations_ms, |
| 354 | + axis=0) |
| 355 | + token_range_durations_ms = [ |
| 356 | + None if duration == -1 else duration |
| 357 | + for duration in token_range_durations_ms |
| 358 | + ] |
| 359 | + |
| 360 | + if len(token_range_durations_ms) != len(token_range_starts): |
| 361 | + raise RuntimeError( |
| 362 | + f"Invalid retention_token_range_durations_ms length. Expected length {len(token_range_starts)}, got length {len(token_range_durations_ms)}" |
| 363 | + ) |
| 364 | + |
| 365 | + ranges = [] |
| 366 | + |
| 367 | + for start, end, priority, duration_ms in zip(token_range_starts, |
| 368 | + token_range_ends, |
| 369 | + token_range_priorities, |
| 370 | + token_range_durations_ms): |
| 371 | + ranges.append( |
| 372 | + trtllm.KvCacheRetentionConfig.TokenRangeRetentionConfig( |
| 373 | + token_start=start, |
| 374 | + token_end=end, |
| 375 | + priority=priority.item(), |
| 376 | + duration_ms=None if duration_ms is None else |
| 377 | + datetime.timedelta(milliseconds=duration_ms.item()))) |
| 378 | + |
| 379 | + decode_args = {} |
| 380 | + |
| 381 | + decode_priority = get_input_scalar_by_name( |
| 382 | + request, "retention_decode_priority", batch_size, batch_index) |
| 383 | + if decode_priority is not None: |
| 384 | + decode_args['decode_retention_priority'] = decode_priority |
| 385 | + |
| 386 | + decode_duration_ms = get_input_scalar_by_name( |
| 387 | + request, "retention_decode_duration_ms", batch_size, batch_index) |
| 388 | + if decode_duration_ms is not None: |
| 389 | + decode_args[ |
| 390 | + 'decode_duration_ms'] = decode_duration_ms if decode_duration_ms != -1 else None |
| 391 | + |
| 392 | + return trtllm.KvCacheRetentionConfig( |
| 393 | + token_range_retention_configs=ranges, **decode_args) |
| 394 | + |
| 395 | + return None |
| 396 | + |
| 397 | + |
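
For reference, a hedged sketch of the per-request inputs this helper expects (tensor names match the reads above; shapes and values are purely illustrative). Each tensor carries a leading batch axis that gets squeezed away, and -1 serves as the "unset" sentinel for ends and durations:

    import numpy as np

    # Two token ranges for one request; the second is open-ended and never expires.
    retention_token_range_starts = np.array([[0, 100]], dtype=np.int32)           # shape (1, num_ranges)
    retention_token_range_ends = np.array([[100, -1]], dtype=np.int32)            # -1 -> None (open-ended)
    retention_token_range_priorities = np.array([[80, 10]], dtype=np.int32)
    retention_token_range_durations_ms = np.array([[30000, -1]], dtype=np.int32)  # -1 -> None (no expiry)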
282 | 398 | def build_1_2_5_buckets(max_value: int) -> List[int]:
|
283 | 399 | """
|
284 | 400 | Builds a list of buckets with increasing powers of 10 multiplied by
|
@@ -375,6 +491,8 @@ def convert_request(request, exclude_input_from_output, decoupled):
|
375 | 491 | request, batch_size, batch_index, input_length)
|
376 | 492 | lora_config = get_lora_config_from_request(request, batch_size,
|
377 | 493 | batch_index)
|
| 494 | + kv_cache_retention_config = get_kv_cache_retention_config_from_request( |
| 495 | + request, batch_size, batch_index) |
378 | 496 |
|
379 | 497 | # Inputs for mllama support
|
380 | 498 | encoder_input_features = get_input_tensor_by_name(
|
@@ -412,11 +530,15 @@ def convert_request(request, exclude_input_from_output, decoupled):
|
412 | 530 | external_draft_tokens_config=external_draft_tokens_config,
|
413 | 531 | prompt_tuning_config=prompt_tuning_config,
|
414 | 532 | lora_config=lora_config,
|
415 | | - )) |
| 533 | + kv_cache_retention_config=kv_cache_retention_config)) |
416 | 534 | return requests
|
417 | 535 |
|
418 | 536 |
|
419 | | -def convert_response(response, batch_index, batch_size, num_return_sequences): |
| 537 | +def convert_response(response, |
| 538 | + batch_index, |
| 539 | + batch_size, |
| 540 | + num_return_sequences, |
| 541 | + expected_logits_dtype=torch.float32): |
420 | 542 |
|
421 | 543 | if response.has_error():
|
422 | 544 | return pb_utils.InferenceResponse(output_tensors=[],
|
@@ -450,18 +572,24 @@ def convert_response(response, batch_index, batch_size, num_return_sequences):
|
450 | 572 | np.expand_dims(np.array(result.log_probs, np.float32), 0)))
|
451 | 573 |
|
452 | 574 | if result.context_logits is not None:
|
| 575 | + assert (result.context_logits.dtype is expected_logits_dtype) |
453 | 576 | output_tensors.append(
|
454 | 577 | pb_utils.Tensor(
|
455 | 578 | "context_logits",
|
456 | | - np.expand_dims(np.array(result.context_logits, np.float32), |
457 | | - 0))) |
| 579 | + np.expand_dims( |
| 580 | + np.array( |
| 581 | + result.context_logits, torch_to_numpy_dtype_dict[ |
| 582 | + result.context_logits.dtype]), 0))) |
458 | 583 |
|
459 | 584 | if result.generation_logits is not None:
|
| 585 | + assert (result.generation_logits.dtype is expected_logits_dtype) |
460 | 586 | output_tensors.append(
|
461 | 587 | pb_utils.Tensor(
|
462 | 588 | "generation_logits",
|
463 | | - np.expand_dims(np.array(result.generation_logits, np.float32), |
464 | | - 0))) |
| 589 | + np.expand_dims( |
| 590 | + np.array( |
| 591 | + result.generation_logits, torch_to_numpy_dtype_dict[ |
| 592 | + result.generation_logits.dtype]), 0))) |
465 | 593 |
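
A minimal sketch of the dtype-preserving output path the new code takes, using float16 as an example (shapes illustrative; assumes a CPU tensor):

    import numpy as np
    import torch

    logits = torch.zeros(4, 32000, dtype=torch.float16)
    np_dtype = torch_to_numpy_dtype_dict[logits.dtype]   # np.float16
    out = np.expand_dims(np.array(logits, np_dtype), 0)  # shape (1, 4, 32000)
    assert out.dtype == np.float16                       # no longer forced to float32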
|
466 | 594 | if batch_size > 1:
|
467 | 595 | output_tensors.append(
|
@@ -555,6 +683,22 @@ def convert_timestamp_to_seconds(timestamp: str):
|
555 | 683 | "%m-%d-%Y %H:%M:%S.%f").timestamp())
|
556 | 684 |
|
557 | 685 |
|
| 686 | +def triton_string_to_torch(dtype): |
| 687 | + type_map = { |
| 688 | + "TYPE_BOOL": torch.bool, |
| 689 | + "TYPE_UINT8": torch.uint8, |
| 690 | + "TYPE_INT8": torch.int8, |
| 691 | + "TYPE_INT16": torch.int16, |
| 692 | + "TYPE_INT32": torch.int32, |
| 693 | + "TYPE_INT64": torch.int64, |
| 694 | + "TYPE_FP16": torch.float16, |
| 695 | + "TYPE_FP32": torch.float32, |
| 696 | + "TYPE_FP64": torch.float64, |
| 697 | + "TYPE_BF16": torch.bfloat16 |
| 698 | + } |
| 699 | + return type_map[dtype] |
| 700 | + |
| 701 | + |
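
Usage is straightforward; any Triton type string outside the table raises a KeyError (torch is already imported at file level):

    assert triton_string_to_torch("TYPE_BF16") is torch.bfloat16
    assert triton_string_to_torch("TYPE_FP32") is torch.float32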
558 | 702 | class TritonPythonModel:
|
559 | 703 | """Your Python model must use the same class name. Every Python model
|
560 | 704 | that is created must have "TritonPythonModel" as the class name.
|
@@ -916,6 +1060,12 @@ def initialize(self, args):
|
916 | 1060 | self.stats_check_period_ms = get_parameter(
|
917 | 1061 | model_config, "stats_check_period_ms", int) or 100
|
918 | 1062 |
|
| 1063 | + self.logits_dtype = None |
| 1064 | + for output in model_config['output']: |
| 1065 | + if output['name'] == 'context_logits' or output[ |
| 1066 | + 'name'] == 'generation_logits': |
| 1067 | + self.logits_dtype = triton_string_to_torch(output['data_type']) |
| 1068 | + |
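
The loop takes the data_type of whichever of context_logits / generation_logits appears last in the output list (in practice the two are expected to agree). A hypothetical model_config fragment it would consume:

    model_config = {
        "output": [
            {"name": "output_ids", "data_type": "TYPE_INT32"},
            {"name": "context_logits", "data_type": "TYPE_FP16"},
            {"name": "generation_logits", "data_type": "TYPE_FP16"},
        ]
    }
    # After the loop: self.logits_dtype == torch.float16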
919 | 1069 | self.create_metrics(args["model_name"],
|
920 | 1070 | args["model_version"],
|
921 | 1071 | is_v1_model=executor_config.batching_type ==
|
@@ -1059,7 +1209,8 @@ def awaiter_loop(self):
|
1059 | 1209 |
|
1060 | 1210 | triton_response, is_final, output_length = convert_response(
|
1061 | 1211 | response, request_data.batch_index,
|
1062 | | - request_data.batch_size, request_data.num_return_sequences) |
| 1212 | + request_data.batch_size, request_data.num_return_sequences, |
| 1213 | + self.logits_dtype) |
1063 | 1214 | with self.lock:
|
1064 | 1215 | self.req_id_to_request_data[
|
1065 | 1216 | req_id].num_output_tokens += output_length
|
|