Skip to content

Commit d3a47c4

Browse files
authored
fix: Address .numpy() issue on fake tensors (#1949)
1 parent 5be3b58 commit d3a47c4

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def compile(
4646
torch_executed_modules=[],
4747
**kwargs,
4848
):
49+
if debug:
50+
logger.setLevel(logging.DEBUG)
4951

5052
logger.warn(
5153
"The Dynamo backend is an experimental feature, for which only the "

py/torch_tensorrt/dynamo/backend/backends.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def aot_torch_tensorrt_aten_backend(
5252
)
5353

5454

55+
@fake_tensor_unsupported
5556
def _pretraced_backend(
5657
gm: torch.fx.GraphModule,
5758
sample_inputs: Sequence[torch.Tensor],
@@ -120,9 +121,7 @@ def _compile_module(
120121
trt_mod = convert_module(
121122
submodule,
122123
submodule_inputs,
123-
debug=settings.debug,
124-
workspace_size=settings.workspace_size,
125-
precision=settings.precision,
124+
settings=settings,
126125
)
127126

128127
# Replace FX Module with TRT Module

py/torch_tensorrt/dynamo/backend/conversion.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,41 @@
22
import torch
33
from torch_tensorrt.fx.trt_module import TRTModule
44
from torch_tensorrt import TRTModuleNext
5+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
56
from torch_tensorrt.fx.fx2trt import (
67
InputTensorSpec,
78
TRTInterpreter,
89
)
9-
from torch_tensorrt.fx.utils import LowerPrecision
1010

1111
import tensorrt as trt
1212

1313

1414
def convert_module(
1515
module: torch.fx.GraphModule,
1616
inputs: Sequence[torch.Tensor],
17-
debug: bool = False,
18-
workspace_size: int = 20 << 30,
19-
precision: LowerPrecision = LowerPrecision.FP32,
17+
settings: CompilationSettings = CompilationSettings(),
2018
) -> Union[TRTModuleNext, TRTModule]:
2119
"""Convert an FX module to a TRT module
2220
Args:
2321
module: FX GraphModule to convert
2422
inputs: Sequence of Tensors representing inputs to the module
25-
debug: Whether to print out verbose debugging information
26-
workspace_size: Maximum workspace TRT is allowed to use for the module
27-
precision: Model Layer precision
23+
settings: Compilation settings
2824
Returns:
2925
TRTModule or TRTModuleNext
3026
"""
3127
interp = TRTInterpreter(
3228
module,
3329
InputTensorSpec.from_tensors(inputs),
3430
explicit_batch_dimension=True,
35-
logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
31+
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
3632
)
3733

3834
r = interp.run(
39-
max_workspace_size=workspace_size,
40-
lower_precision=precision,
35+
max_workspace_size=settings.workspace_size,
36+
lower_precision=settings.precision,
4137
profiling_verbosity=(
4238
trt.ProfilingVerbosity.VERBOSE
43-
if debug
39+
if settings.debug
4440
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
4541
),
4642
)

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,18 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
136136
f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
137137
)
138138

139-
logger.debug("\nSupported Nodes:")
139+
# Reformat support messages for debugger to print node overview as a single string
140+
supported_nodes_str = "\nSupported Nodes:\n"
140141
for node_name in self.supported_operators:
141-
logger.debug("-", node_name)
142+
supported_nodes_str += f"- {node_name}\n"
143+
144+
logger.debug(supported_nodes_str)
142145

143146
if len(self.unsupported_operators) != 0:
144-
logger.debug("\nUnsupported or Excluded Nodes:")
147+
unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
145148
for node_name in self.unsupported_operators:
146-
logger.debug("-", node_name)
147-
logger.debug("\n")
149+
unsupported_nodes_str += f"- {node_name}\n"
150+
logger.debug(unsupported_nodes_str)
148151
else:
149152
logger.debug("\nAll Nodes Supported\n")
150153

0 commit comments

Comments
 (0)