Skip to content

Commit 27ddf38

Browse files
authored
Changed the debug setting (#3551)
1 parent 3d8438f commit 27ddf38

File tree

11 files changed

+183
-36
lines changed

11 files changed

+183
-36
lines changed

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
load_cross_compiled_exported_program,
1515
save_cross_compiled_exported_program,
1616
)
17-
from ._Debugger import Debugger
1817
from ._exporter import export
1918
from ._refit import refit_module_weights
2019
from ._settings import CompilationSettings
2120
from ._SourceIR import SourceIR
2221
from ._tracer import trace
22+
from .debug._Debugger import Debugger

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections.abc
44
import logging
5+
import os
56
import platform
67
import warnings
78
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -31,6 +32,8 @@
3132
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
3233
DYNAMO_CONVERTERS as CONVERTERS,
3334
)
35+
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
36+
from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
3437
from torch_tensorrt.dynamo.lowering import (
3538
get_decompositions,
3639
post_lowering,
@@ -41,7 +44,6 @@
4144
get_output_metadata,
4245
parse_graph_io,
4346
prepare_inputs,
44-
set_log_level,
4547
to_torch_device,
4648
to_torch_tensorrt_device,
4749
)
@@ -63,7 +65,7 @@ def cross_compile_for_windows(
6365
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
6466
] = _defaults.ENABLED_PRECISIONS,
6567
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
66-
debug: bool = _defaults.DEBUG,
68+
debug: bool = False,
6769
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
6870
workspace_size: int = _defaults.WORKSPACE_SIZE,
6971
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -184,7 +186,11 @@ def cross_compile_for_windows(
184186
)
185187

186188
if debug:
187-
set_log_level(logger.parent, logging.DEBUG)
189+
warnings.warn(
190+
"`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
191+
DeprecationWarning,
192+
stacklevel=2,
193+
)
188194

189195
if "truncate_long_and_double" in kwargs.keys():
190196
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -295,7 +301,6 @@ def cross_compile_for_windows(
295301
"enabled_precisions": (
296302
enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
297303
),
298-
"debug": debug,
299304
"device": device,
300305
"assume_dynamic_shape_support": assume_dynamic_shape_support,
301306
"workspace_size": workspace_size,
@@ -386,7 +391,7 @@ def compile(
386391
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
387392
] = _defaults.ENABLED_PRECISIONS,
388393
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
389-
debug: bool = _defaults.DEBUG,
394+
debug: bool = False,
390395
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
391396
workspace_size: int = _defaults.WORKSPACE_SIZE,
392397
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -503,6 +508,13 @@ def compile(
503508
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
504509
"""
505510

511+
if debug:
512+
warnings.warn(
513+
"`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality",
514+
DeprecationWarning,
515+
stacklevel=2,
516+
)
517+
506518
if "truncate_long_and_double" in kwargs.keys():
507519
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
508520
raise ValueError(
@@ -633,7 +645,6 @@ def compile(
633645
"enabled_precisions": (
634646
enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
635647
),
636-
"debug": debug,
637648
"device": device,
638649
"assume_dynamic_shape_support": assume_dynamic_shape_support,
639650
"workspace_size": workspace_size,
@@ -694,12 +705,15 @@ def compile(
694705
return trt_gm
695706

696707

708+
@fn_supports_debugger
697709
def compile_module(
698710
gm: torch.fx.GraphModule,
699711
sample_arg_inputs: Sequence[Input],
700712
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
701713
settings: CompilationSettings = CompilationSettings(),
702714
engine_cache: Optional[BaseEngineCache] = None,
715+
*,
716+
_debugger_settings: Optional[DebuggerConfig] = None,
703717
) -> torch.fx.GraphModule:
704718
"""Compile a traced FX module
705719
@@ -900,6 +914,34 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
900914

901915
trt_modules[name] = trt_module
902916

917+
if _debugger_settings:
918+
919+
if _debugger_settings.save_engine_profile:
920+
if settings.use_python_runtime:
921+
if _debugger_settings.profile_format == "trex":
922+
logger.warning(
923+
"Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
924+
)
925+
trt_module.enable_profiling()
926+
else:
927+
path = os.path.join(
928+
_debugger_settings.logging_dir, "engine_visualization"
929+
)
930+
os.makedirs(path, exist_ok=True)
931+
trt_module.enable_profiling(
932+
profiling_results_dir=path,
933+
profile_format=_debugger_settings.profile_format,
934+
)
935+
936+
if _debugger_settings.save_layer_info:
937+
with open(
938+
os.path.join(
939+
_debugger_settings.logging_dir, "engine_layer_info.json"
940+
),
941+
"w",
942+
) as f:
943+
f.write(trt_module.get_layer_info())
944+
903945
# Parse the graph I/O and store it in dryrun tracker
904946
parse_graph_io(gm, dryrun_tracker)
905947

@@ -927,7 +969,7 @@ def convert_exported_program_to_serialized_trt_engine(
927969
enabled_precisions: (
928970
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
929971
) = _defaults.ENABLED_PRECISIONS,
930-
debug: bool = _defaults.DEBUG,
972+
debug: bool = False,
931973
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
932974
workspace_size: int = _defaults.WORKSPACE_SIZE,
933975
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1029,7 +1071,11 @@ def convert_exported_program_to_serialized_trt_engine(
10291071
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
10301072
"""
10311073
if debug:
1032-
set_log_level(logger.parent, logging.DEBUG)
1074+
warnings.warn(
1075+
"`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
1076+
DeprecationWarning,
1077+
stacklevel=2,
1078+
)
10331079

10341080
if "truncate_long_and_double" in kwargs.keys():
10351081
if truncate_double is not _defaults.TRUNCATE_DOUBLE:

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from torch_tensorrt._enums import EngineCapability, dtype
77

88
ENABLED_PRECISIONS = {dtype.f32}
9-
DEBUG = False
109
DEVICE = None
1110
DISABLE_TF32 = False
1211
ASSUME_DYNAMIC_SHAPE_SUPPORT = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from torch_tensorrt.dynamo._defaults import (
88
ASSUME_DYNAMIC_SHAPE_SUPPORT,
99
CACHE_BUILT_ENGINES,
10-
DEBUG,
1110
DISABLE_TF32,
1211
DLA_GLOBAL_DRAM_SIZE,
1312
DLA_LOCAL_DRAM_SIZE,
@@ -100,7 +99,6 @@ class CompilationSettings:
10099
"""
101100

102101
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
103-
debug: bool = DEBUG
104102
workspace_size: int = WORKSPACE_SIZE
105103
min_block_size: int = MIN_BLOCK_SIZE
106104
torch_executed_ops: Collection[Target] = field(default_factory=set)

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import torch
88
from torch.export import Dim, export
99
from torch_tensorrt._Input import Input
10-
from torch_tensorrt.dynamo._defaults import DEBUG, default_device
11-
from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
10+
from torch_tensorrt.dynamo._defaults import default_device
11+
from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -70,10 +70,6 @@ def trace(
7070
if kwarg_inputs is None:
7171
kwarg_inputs = {}
7272

73-
debug = kwargs.get("debug", DEBUG)
74-
if debug:
75-
set_log_level(logger.parent, logging.DEBUG)
76-
7773
device = to_torch_device(kwargs.get("device", default_device()))
7874
torch_arg_inputs = get_torch_inputs(arg_inputs, device)
7975
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
get_trt_tensor,
4646
to_torch,
4747
)
48+
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
49+
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
4850
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
4951
from torch_tensorrt.fx.observer import Observer
5052
from torch_tensorrt.logging import TRT_LOGGER
@@ -70,6 +72,7 @@ class TRTInterpreterResult(NamedTuple):
7072
requires_output_allocator: bool
7173

7274

75+
@cls_supports_debugger
7376
class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc]
7477
def __init__(
7578
self,
@@ -78,12 +81,14 @@ def __init__(
7881
output_dtypes: Optional[Sequence[dtype]] = None,
7982
compilation_settings: CompilationSettings = CompilationSettings(),
8083
engine_cache: Optional[BaseEngineCache] = None,
84+
*,
85+
_debugger_settings: Optional[DebuggerConfig] = None,
8186
):
8287
super().__init__(module)
8388

8489
self.logger = TRT_LOGGER
8590
self.builder = trt.Builder(self.logger)
86-
91+
self._debugger_settings = _debugger_settings
8792
flag = 0
8893
if compilation_settings.use_explicit_typing:
8994
STRONGLY_TYPED = 1 << (int)(
@@ -204,7 +209,7 @@ def _populate_trt_builder_config(
204209
) -> trt.IBuilderConfig:
205210
builder_config = self.builder.create_builder_config()
206211

207-
if self.compilation_settings.debug:
212+
if self._debugger_settings and self._debugger_settings.engine_builder_monitor:
208213
builder_config.progress_monitor = TRTBulderMonitor()
209214

210215
if self.compilation_settings.workspace_size != 0:
@@ -215,7 +220,8 @@ def _populate_trt_builder_config(
215220
if version.parse(trt.__version__) >= version.parse("8.2"):
216221
builder_config.profiling_verbosity = (
217222
trt.ProfilingVerbosity.DETAILED
218-
if self.compilation_settings.debug
223+
if self._debugger_settings
224+
and self._debugger_settings.save_engine_profile
219225
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
220226
)
221227

0 commit comments

Comments
 (0)