
Commit a0b840b

chore: Deprecate truncate_long_and_double for the dynamo frontend

`truncate_long_and_double` has been deprecated in favor of `truncate_double`, as int64 is now natively supported.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent aaab1f9 · commit a0b840b

20 files changed: +99 -306 lines
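For users of the dynamo frontend, migration is a one-keyword change. Below is a minimal before/after sketch, assuming a CUDA-capable environment and the `torch_tensorrt.dynamo.compile` entry point touched in this diff; the toy model and shapes are illustrative only:

```python
import torch
import torch_tensorrt

# Toy float64 model; real workloads would export their own module.
model = torch.nn.Linear(4, 2).double().cuda().eval()
inputs = [torch.randn(8, 4, dtype=torch.float64, device="cuda")]
exported = torch.export.export(model, tuple(inputs))

# Deprecated spelling (now emits a DeprecationWarning):
# trt_mod = torch_tensorrt.dynamo.compile(exported, inputs, truncate_long_and_double=True)

# New spelling: only float64 truncation is requested; int64 is handled natively.
trt_mod = torch_tensorrt.dynamo.compile(exported, inputs, truncate_double=True)
```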

py/torch_tensorrt/dynamo/_compiler.py
+40 -11
```diff
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
@@ -22,7 +23,7 @@
     UnsupportedOperatorException,
     convert_module,
     interpret_module_to_result,
-    repair_long_or_double_inputs,
+    repair_double_inputs,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -58,7 +59,7 @@ def compile(
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
     dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
     dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
-    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
     require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Collection[Target]] = None,
@@ -74,7 +75,7 @@ def compile(
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
-    """Compile a TorchScript module for NVIDIA GPUs using TensorRT
+    """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
 
     Takes a existing TorchScript module and a set of settings to configure the compiler
     and will convert methods to JIT Graphs which call equivalent TensorRT engines
@@ -115,7 +116,7 @@ def compile(
         dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
         dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
         dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
-        truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
+        truncate_double (bool): Truncate weights provided in double (float64) to float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
@@ -138,6 +139,19 @@ def compile(
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+    if "truncate_long_and_double" in kwargs.keys():
+        if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+            raise ValueError(
+                'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+            )
+        else:
+            truncate_double = kwargs["truncate_long_and_double"]
+            warnings.warn(
+                'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
     engine_capability = EngineCapability._from(engine_capability)
 
     if torch_executed_modules is not None and torch_executed_modules:
@@ -185,7 +199,7 @@
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
-        "truncate_long_and_double": truncate_long_and_double,
+        "truncate_double": truncate_double,
         "use_fast_partitioner": use_fast_partitioner,
         "num_avg_timing_iters": num_avg_timing_iters,
         "enable_experimental_decompositions": enable_experimental_decompositions,
@@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         assert submodule_inputs is not None
         # Handle long/double inputs if requested by the user
-        if settings.truncate_long_and_double:
-            submodule_inputs = repair_long_or_double_inputs(
+        if settings.truncate_double:
+            submodule_inputs = repair_double_inputs(
                 partitioned_module,
                 submodule,
                 submodule_inputs,
@@ -423,7 +437,8 @@
 
 def convert_module_to_trt_engine(
     exported_program: ExportedProgram,
-    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+    inputs: Tuple[Any, ...],
+    *,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
@@ -436,7 +451,7 @@ def convert_module_to_trt_engine(
     version_compatible: bool = _defaults.VERSION_COMPATIBLE,
     optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
     use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
-    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
     use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     device: Device = Device._current_device(),
@@ -451,6 +466,7 @@ def convert_module_to_trt_engine(
     dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
     calibrator: object = None,
     allow_shape_tensors: bool = False,
+    **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
 
@@ -488,7 +504,7 @@ def convert_module_to_trt_engine(
         use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
             based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
             argument as None
-        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
            or only a selected subset of them
@@ -512,6 +528,19 @@ def convert_module_to_trt_engine(
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+    if "truncate_long_and_double" in kwargs.keys():
+        if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+            raise ValueError(
+                'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+            )
+        else:
+            truncate_double = kwargs["truncate_long_and_double"]
+            warnings.warn(
+                'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
     input_list = list(inputs) if inputs is not None else []
     torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
     # Prepare torch_trt inputs
@@ -531,7 +560,7 @@ def convert_module_to_trt_engine(
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
-        "truncate_long_and_double": truncate_long_and_double,
+        "truncate_double": truncate_double,
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "device": device,
```

py/torch_tensorrt/dynamo/_defaults.py
+1 -1

```diff
@@ -18,7 +18,7 @@
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
 SPARSE_WEIGHTS = False
-TRUNCATE_LONG_AND_DOUBLE = False
+TRUNCATE_DOUBLE = False
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
```

py/torch_tensorrt/dynamo/_settings.py
+3 -3

```diff
@@ -23,7 +23,7 @@
     REFIT,
     REQUIRE_FULL_COMPILATION,
     SPARSE_WEIGHTS,
-    TRUNCATE_LONG_AND_DOUBLE,
+    TRUNCATE_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
@@ -50,7 +50,7 @@ class CompilationSettings:
         use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
             based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
             argument as None
-        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
         use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
         enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
             or only a selected subset of them
@@ -81,7 +81,7 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
-    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    truncate_double: bool = TRUNCATE_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
```
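With the renamed field, downstream construction of settings is unchanged apart from the keyword; a short sketch against the dataclass shown above:

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(truncate_double=True)
assert settings.truncate_double is True

# The old field no longer exists, so stale call sites fail loudly:
# CompilationSettings(truncate_long_and_double=True)  # raises TypeError
```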

py/torch_tensorrt/dynamo/conversion/__init__.py
+1 -1

```diff
@@ -3,4 +3,4 @@
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
 from ._TRTInterpreter import *  # noqa: F403
-from .truncate_long_and_double import repair_long_or_double_inputs
+from .truncate_double import repair_double_inputs
```

py/torch_tensorrt/dynamo/conversion/_conversion.py
+5 -6

```diff
@@ -4,7 +4,6 @@
 import logging
 from typing import List, Sequence
 
-import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
@@ -18,14 +17,16 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
+import tensorrt as trt
+
 logger = logging.getLogger(__name__)
 
 
 def infer_module_output_dtypes(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
     device: Device,
-    truncate_long_and_double: bool = False,
+    truncate_double: bool = False,
 ) -> List[dtype]:
     torch_inputs = get_torch_inputs(inputs, device)
     module = module.to(device.to(torch.device))
@@ -48,10 +49,8 @@ def infer_module_output_dtypes(
         else:
             output_ = torch.tensor(output)
 
-        if truncate_long_and_double and output_.dtype == dtype.float64:
+        if truncate_double and output_.dtype == dtype.float64:
             output_dtypes.append(dtype.float32)
-        elif truncate_long_and_double and output_.dtype == dtype.int64:
-            output_dtypes.append(dtype.int32)
         else:
             output_dtypes.append(dtype._from(output_.dtype))
 
@@ -75,7 +74,7 @@ def interpret_module_to_result(
         module,
         inputs,
         settings.device,
-        truncate_long_and_double=settings.truncate_long_and_double,
+        truncate_double=settings.truncate_double,
     )
 
     interpreter = TRTInterpreter(
```
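The narrowed branch in `infer_module_output_dtypes` means only float64 outputs are remapped; int64 outputs now pass through unchanged. A standalone sketch of the resulting mapping (plain PyTorch, not the torch_tensorrt helper itself):

```python
import torch


def map_output_dtype(output_dtype: torch.dtype, truncate_double: bool) -> torch.dtype:
    # float64 is truncated on request; int64 is kept as-is, since TensorRT
    # now supports it natively.
    if truncate_double and output_dtype == torch.float64:
        return torch.float32
    return output_dtype


assert map_output_dtype(torch.float64, True) == torch.float32
assert map_output_dtype(torch.int64, True) == torch.int64  # no longer int32
```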

py/torch_tensorrt/dynamo/conversion/converter_utils.py
+4 -14

```diff
@@ -297,21 +297,11 @@ def get_trt_tensor(
         A TensorRT ITensor that represents the given value.
     """
     # If the input is 64-bit, cast it to 32-bit for TRT freezing
-    if (
-        isinstance(input_val, torch.Tensor)
-        and ctx.compilation_settings.truncate_long_and_double
-    ):
-        if input_val.dtype == torch.int64:
-            input_val = input_val.to(torch.int32)
-        elif input_val.dtype == torch.float64:
+    if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double:
+        if input_val.dtype == torch.float64:
             input_val = input_val.to(torch.float32)
-    elif (
-        isinstance(input_val, np.ndarray)
-        and ctx.compilation_settings.truncate_long_and_double
-    ):
-        if input_val.dtype == np.int64:
-            input_val = input_val.astype(np.int32)
-        elif input_val.dtype == np.float64:
+    elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double:
+        if input_val.dtype == np.float64:
             input_val = input_val.astype(np.float32)
 
     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
```
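After the simplification, `get_trt_tensor` only ever downcasts float64; a self-contained sketch of the equivalent logic for both container types:

```python
import numpy as np
import torch


def truncate_if_double(val, truncate_double: bool):
    # Mirrors the branch above: only float64 is downcast, and only when
    # the truncate_double setting is enabled; int64 is left untouched.
    if truncate_double and isinstance(val, torch.Tensor) and val.dtype == torch.float64:
        return val.to(torch.float32)
    if truncate_double and isinstance(val, np.ndarray) and val.dtype == np.float64:
        return val.astype(np.float32)
    return val


assert truncate_if_double(torch.zeros(2, dtype=torch.float64), True).dtype == torch.float32
assert truncate_if_double(np.zeros(2, dtype=np.int64), True).dtype == np.int64
```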

py/torch_tensorrt/dynamo/conversion/impl/embedding.py
+10 -13

```diff
@@ -9,25 +9,22 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+
+import tensorrt as trt
 
 
 def embedding(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    weight: TRTTensor,
+    input: trt.ITensor,
+    weight: trt.ITensor,
     scale_grad_by_freq: bool,
     sparse: bool,
-) -> TRTTensor:
+) -> trt.ITensor:
     indices_tensor = input
     embedding_tensor = weight
-    if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
-        raise RuntimeError(
-            "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
-        )
     indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
     embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
     # unsupported parameters
@@ -52,15 +49,15 @@ def embedding_bag(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    weight: TRTTensor,
-    indices: TRTTensor,
+    weight: trt.ITensor,
+    indices: trt.ITensor,
     offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
     scale_grad_by_freq: bool,
     mode: int,
     sparse: bool,
-    per_sample_weights: Optional[TRTTensor],
+    per_sample_weights: Optional[trt.ITensor],
     include_last_offset: bool,
-) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
+) -> Tuple[trt.ITensor, trt.ITensor, trt.ITensor, trt.ITensor]:
     """
     This function is for calculating embedding bags.
 
@@ -143,7 +140,7 @@ def embedding_bag(
     # however, pytorch doc says if `include_last_offset` is True, the size of offsets
     # is equal to the number of bags + 1. The last element is the size of the input,
     # or the ending index position of the last bag (sequence).
-    offsets[-1] = indices.shape[0]
+    offsets[-1] = indices.shape[0]  # type: ignore[index]
 
     # separately reduce embeddings for different bags
     reduced_embed = []
```
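Dropping the int64 guard in the `embedding` converter reflects the same change: indices no longer need a manual cast to int32 before reaching TensorRT. Eager PyTorch already produces int64 indices by default, as a quick plain-PyTorch illustration shows:

```python
import torch

weight = torch.randn(10, 3)
indices = torch.tensor([1, 4, 7])  # int64 by default
out = torch.ops.aten.embedding.default(weight, indices)
assert indices.dtype == torch.int64 and out.shape == (3, 3)
```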
