
Commit 77df69d

chore: Deprecate truncate_long_and_double for the dynamo frontend

`truncate_long_and_double` has been deprecated in favor of `truncate_double`, as int64 is now natively supported.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent d4e59b1 commit 77df69d

Showing 21 changed files with 99 additions and 82 deletions.
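For callers of the dynamo frontend, the change is a keyword rename with narrower semantics (float64 only). A minimal migration sketch, assuming the `torch_tensorrt.dynamo.compile` entry point changed in this commit; the model and shapes are placeholders:

```python
import torch
import torch_tensorrt

# Placeholder module; any exportable model works the same way.
model = torch.nn.Linear(8, 4).double().eval().cuda()
example = (torch.randn(2, 8, dtype=torch.float64).cuda(),)
exp_program = torch.export.export(model, example)

# Deprecated spelling: still accepted, but emits a DeprecationWarning
# and is forwarded to truncate_double by the shim added in this commit.
trt_mod = torch_tensorrt.dynamo.compile(
    exp_program, inputs=list(example), truncate_long_and_double=True
)

# New spelling: only float64 truncation is configurable; int64 is native.
trt_mod = torch_tensorrt.dynamo.compile(
    exp_program, inputs=list(example), truncate_double=True
)
```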

.pre-commit-config.yaml (+1, -1)

```diff
@@ -16,7 +16,7 @@ repos:
         - --fix=lf
         exclude: ^docs
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.1
+    rev: v14.0.6
     hooks:
       - id: clang-format
         types_or: [c++, c, cuda]
```

core/util/trt_util.cpp (+2, -2)

```diff
@@ -164,8 +164,8 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
   // Acceptable range for pos is [-d.nbDims - 1, d.nbDims]
   TORCHTRT_ASSERT(
       pos >= (-d.nbDims - 1) && pos <= d.nbDims,
-      "ERROR: Index to unsqueeze is out of bounds. " << "Expected value in range [" << (-d.nbDims - 1) << ", "
-                                                     << d.nbDims << "], but got " << pos);
+      "ERROR: Index to unsqueeze is out of bounds. "
+          << "Expected value in range [" << (-d.nbDims - 1) << ", " << d.nbDims << "], but got " << pos);
 
   // Unsqueeze with negative dimensions creates a new dimension at that index
   pos = (pos < 0) ? (pos + d.nbDims + 1) : pos;
```

py/torch_tensorrt/dynamo/_compiler.py (+40, -11)

```diff
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
@@ -22,7 +23,7 @@
     UnsupportedOperatorException,
     convert_module,
     interpret_module_to_result,
-    repair_long_or_double_inputs,
+    repair_double_inputs,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -58,7 +59,7 @@ def compile(
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
     dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
     dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
-    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
     require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Collection[Target]] = None,
@@ -74,7 +75,7 @@ def compile(
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
-    """Compile a TorchScript module for NVIDIA GPUs using TensorRT
+    """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
 
     Takes a existing TorchScript module and a set of settings to configure the compiler
     and will convert methods to JIT Graphs which call equivalent TensorRT engines
@@ -115,7 +116,7 @@ def compile(
         dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
         dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
         dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
-        truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
+        truncate_double (bool): Truncate weights provided in double (float64) to float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
@@ -138,6 +139,19 @@ def compile(
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+    if "truncate_long_and_double" in kwargs.keys():
+        if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+            raise ValueError(
+                'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+            )
+        else:
+            truncate_double = kwargs["truncate_long_and_double"]
+            warnings.warn(
+                'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
     engine_capability = EngineCapability._from(engine_capability)
 
     if torch_executed_modules is not None and torch_executed_modules:
@@ -185,7 +199,7 @@ def compile(
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
-        "truncate_long_and_double": truncate_long_and_double,
+        "truncate_double": truncate_double,
         "use_fast_partitioner": use_fast_partitioner,
         "num_avg_timing_iters": num_avg_timing_iters,
         "enable_experimental_decompositions": enable_experimental_decompositions,
@@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         assert submodule_inputs is not None
         # Handle long/double inputs if requested by the user
-        if settings.truncate_long_and_double:
-            submodule_inputs = repair_long_or_double_inputs(
+        if settings.truncate_double:
+            submodule_inputs = repair_double_inputs(
                 partitioned_module,
                 submodule,
                 submodule_inputs,
@@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
 def convert_module_to_trt_engine(
     exported_program: ExportedProgram,
-    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+    inputs: Tuple[Any, ...],
+    *,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
@@ -436,7 +451,7 @@ def convert_module_to_trt_engine(
     version_compatible: bool = _defaults.VERSION_COMPATIBLE,
     optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
     use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
-    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
     use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     device: Device = Device._current_device(),
@@ -451,6 +466,7 @@ def convert_module_to_trt_engine(
     dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
     calibrator: object = None,
     allow_shape_tensors: bool = False,
+    **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
 
@@ -488,7 +504,7 @@ def convert_module_to_trt_engine(
         use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
             based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
             argument as None
-        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
         use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
         enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
             or only a selected subset of them
@@ -512,6 +528,19 @@ def convert_module_to_trt_engine(
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
+    if "truncate_long_and_double" in kwargs.keys():
+        if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+            raise ValueError(
+                'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+            )
+        else:
+            truncate_double = kwargs["truncate_long_and_double"]
+            warnings.warn(
+                'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
     input_list = list(inputs) if inputs is not None else []
     torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
     # Prepare torch_trt inputs
@@ -531,7 +560,7 @@ def convert_module_to_trt_engine(
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
-        "truncate_long_and_double": truncate_long_and_double,
+        "truncate_double": truncate_double,
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "device": device,
```

py/torch_tensorrt/dynamo/_defaults.py (+1, -1)

```diff
@@ -18,7 +18,7 @@
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
 SPARSE_WEIGHTS = False
-TRUNCATE_LONG_AND_DOUBLE = False
+TRUNCATE_DOUBLE = False
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
```

py/torch_tensorrt/dynamo/_settings.py (+3, -3)

```diff
@@ -23,7 +23,7 @@
     REFIT,
     REQUIRE_FULL_COMPILATION,
     SPARSE_WEIGHTS,
-    TRUNCATE_LONG_AND_DOUBLE,
+    TRUNCATE_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
@@ -50,7 +50,7 @@ class CompilationSettings:
         use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
             based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
             argument as None
-        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
         use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
         enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
             or only a selected subset of them
@@ -81,7 +81,7 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
-    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    truncate_double: bool = TRUNCATE_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
```

py/torch_tensorrt/dynamo/conversion/__init__.py (+1, -1)

```diff
@@ -3,4 +3,4 @@
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
 from ._TRTInterpreter import *  # noqa: F403
-from .truncate_long_and_double import repair_long_or_double_inputs
+from .truncate_double import repair_double_inputs
```

py/torch_tensorrt/dynamo/conversion/_conversion.py (+5, -6)

```diff
@@ -4,7 +4,6 @@
 import logging
 from typing import List, Sequence
 
-import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
@@ -18,14 +17,16 @@
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
+import tensorrt as trt
+
 logger = logging.getLogger(__name__)
 
 
 def infer_module_output_dtypes(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
     device: Device,
-    truncate_long_and_double: bool = False,
+    truncate_double: bool = False,
 ) -> List[dtype]:
     torch_inputs = get_torch_inputs(inputs, device)
     module = module.to(device.to(torch.device))
@@ -48,10 +49,8 @@ def infer_module_output_dtypes(
         else:
             output_ = torch.tensor(output)
 
-        if truncate_long_and_double and output_.dtype == dtype.float64:
+        if truncate_double and output_.dtype == dtype.float64:
            output_dtypes.append(dtype.float32)
-        elif truncate_long_and_double and output_.dtype == dtype.int64:
-            output_dtypes.append(dtype.int32)
         else:
             output_dtypes.append(dtype._from(output_.dtype))
 
@@ -75,7 +74,7 @@ def interpret_module_to_result(
         module,
         inputs,
         settings.device,
-        truncate_long_and_double=settings.truncate_long_and_double,
+        truncate_double=settings.truncate_double,
     )
 
     interpreter = TRTInterpreter(
```
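The output-dtype inference now remaps only float64. A minimal sketch of the retained mapping, using plain `torch.dtype` rather than the `torch_tensorrt` `dtype` enum used in the real function:

```python
import torch


def infer_output_dtype(output: torch.Tensor, truncate_double: bool) -> torch.dtype:
    # Only float64 is remapped now; int64 outputs keep their native dtype.
    if truncate_double and output.dtype == torch.float64:
        return torch.float32
    return output.dtype


assert infer_output_dtype(torch.zeros(1, dtype=torch.float64), True) == torch.float32
assert infer_output_dtype(torch.zeros(1, dtype=torch.int64), True) == torch.int64
```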

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+4, -14)

```diff
@@ -297,21 +297,11 @@ def get_trt_tensor(
         A TensorRT ITensor that represents the given value.
     """
     # If the input is 64-bit, cast it to 32-bit for TRT freezing
-    if (
-        isinstance(input_val, torch.Tensor)
-        and ctx.compilation_settings.truncate_long_and_double
-    ):
-        if input_val.dtype == torch.int64:
-            input_val = input_val.to(torch.int32)
-        elif input_val.dtype == torch.float64:
+    if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double:
+        if input_val.dtype == torch.float64:
             input_val = input_val.to(torch.float32)
-    elif (
-        isinstance(input_val, np.ndarray)
-        and ctx.compilation_settings.truncate_long_and_double
-    ):
-        if input_val.dtype == np.int64:
-            input_val = input_val.astype(np.int32)
-        elif input_val.dtype == np.float64:
+    elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double:
+        if input_val.dtype == np.float64:
             input_val = input_val.astype(np.float32)
 
     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
```
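The narrowed branch in `get_trt_tensor` is easy to verify in isolation. A sketch with `truncate_double` passed directly instead of read from `ctx.compilation_settings`:

```python
import numpy as np
import torch


def narrow_64bit(input_val, truncate_double: bool):
    # float64 is downcast on request; int64 is left intact for native TRT support.
    if isinstance(input_val, torch.Tensor) and truncate_double:
        if input_val.dtype == torch.float64:
            input_val = input_val.to(torch.float32)
    elif isinstance(input_val, np.ndarray) and truncate_double:
        if input_val.dtype == np.float64:
            input_val = input_val.astype(np.float32)
    return input_val


assert narrow_64bit(torch.zeros(2, dtype=torch.float64), True).dtype == torch.float32
assert narrow_64bit(torch.zeros(2, dtype=torch.int64), True).dtype == torch.int64
assert narrow_64bit(np.zeros(2, dtype=np.float64), True).dtype == np.float32
```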

py/torch_tensorrt/dynamo/conversion/impl/embedding.py (+10, -13)

```diff
@@ -9,25 +9,22 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+
+import tensorrt as trt
 
 
 def embedding(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
-    weight: TRTTensor,
+    input: trt.ITensor,
+    weight: trt.ITensor,
     scale_grad_by_freq: bool,
     sparse: bool,
-) -> TRTTensor:
+) -> trt.ITensor:
     indices_tensor = input
     embedding_tensor = weight
-    if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
-        raise RuntimeError(
-            "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
-        )
     indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
     embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
     # unsupported parameters
@@ -52,15 +49,15 @@ def embedding_bag(
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
-    weight: TRTTensor,
-    indices: TRTTensor,
+    weight: trt.ITensor,
+    indices: trt.ITensor,
     offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
     scale_grad_by_freq: bool,
     mode: int,
     sparse: bool,
-    per_sample_weights: Optional[TRTTensor],
+    per_sample_weights: Optional[trt.ITensor],
     include_last_offset: bool,
-) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
+) -> Tuple[trt.ITensor, trt.ITensor, trt.ITensor, trt.ITensor]:
     """
     This function is for calculating embedding bags.
 
@@ -143,7 +140,7 @@ def embedding_bag(
     # however, pytorch doc says if `include_last_offset` is True, the size of offsets
     # is equal to the number of bags + 1. The last element is the size of the input,
     # or the ending index position of the last bag (sequence).
-    offsets[-1] = indices.shape[0]
+    offsets[-1] = indices.shape[0]  # type: ignore[index]
 
     # separately reduce embeddings for different bags
     reduced_embed = []
```
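With the int64 guard deleted, long-typed indices flow into `get_trt_tensor` unmodified and are handed to TensorRT natively. A hedged usage sketch; the export and compile calls are illustrative, not taken from this diff:

```python
import torch
import torch_tensorrt


class Embed(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(100, 16)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.emb(idx)


model = Embed().eval().cuda()
# int64 indices previously had to be truncated to int32 via
# truncate_long_and_double=True; after this commit they pass through natively.
idx = torch.randint(0, 100, (4, 8), dtype=torch.long).cuda()
exp_program = torch.export.export(model, (idx,))
trt_mod = torch_tensorrt.dynamo.compile(exp_program, inputs=[idx])
```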

py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py renamed to py/torch_tensorrt/dynamo/conversion/truncate_double.py (+1, -1)

```diff
@@ -156,7 +156,7 @@ def _repair_64bit_input(
     gm.recompile()
 
 
-def repair_long_or_double_inputs(
+def repair_double_inputs(
     parent_graph: torch.fx.GraphModule,
     submodule: torch.fx.GraphModule,
     submodule_inputs: Sequence[Input],
```

tests/py/dynamo/backend/test_backend_compiler.py (+6, -7)

```diff
@@ -221,7 +221,7 @@ def forward(self, x, y):
             inputs,
             min_block_size=1,
             pass_through_build_failures=True,
-            truncate_long_and_double=True,
+            truncate_double=True,
             debug=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -240,16 +240,14 @@ def forward(self, x, y):
     def test_int64_input_partial_support(self):
         class PartiallySupportedMultiOp(torch.nn.Module):
             def forward(self, x, y):
-                return torch.ops.aten.div.Tensor_mode(
-                    x, torch.ops.aten.add.Tensor(y, y), rounding_mode=None
-                )
+                return torch.ops.aten.abs(torch.ops.aten.add.Tensor(x, y))
 
         fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
         unexpected_ops = {torch.ops.aten.add.Tensor}
 
         inputs = [
-            torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
-            torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
+            torch.randint(-40, 40, (1, 16, 7, 5), dtype=torch.long).cuda(),
+            torch.randint(1, 40, (1, 16, 7, 5), dtype=torch.long).cuda(),
         ]
 
         (
@@ -296,8 +294,9 @@ def forward(self, x, y):
             inputs,
             min_block_size=1,
             pass_through_build_failures=True,
-            truncate_long_and_double=True,
+            truncate_double=False,
             debug=True,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
         torch_model_results = fx_graph(*inputs).detach().cpu()
```
