@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -31,6+32,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -41,7 +44,6 @@
     get_output_metadata,
     parse_graph_io,
     prepare_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -63,7 +65,7 @@ def cross_compile_for_windows(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -184,7 +186,11 @@ def cross_compile_for_windows(
         )
 
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -295,7 +301,6 @@ def cross_compile_for_windows(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -386,7 +391,7 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -503,6 +508,13 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
+    if debug:
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -633,7 +645,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -694,12 +705,15 @@ def compile(
     return trt_gm
 
 
+@fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,
+    *,
+    _debugger_settings: Optional[DebuggerConfig] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
 
@@ -900,6 +914,34 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         trt_modules[name] = trt_module
 
+        if _debugger_settings:
+
+            if _debugger_settings.save_engine_profile:
+                if settings.use_python_runtime:
+                    if _debugger_settings.profile_format == "trex":
+                        logger.warning(
+                            "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
+                        )
+                    trt_module.enable_profiling()
+                else:
+                    path = os.path.join(
+                        _debugger_settings.logging_dir, "engine_visualization"
+                    )
+                    os.makedirs(path, exist_ok=True)
+                    trt_module.enable_profiling(
+                        profiling_results_dir=path,
+                        profile_format=_debugger_settings.profile_format,
+                    )
+
+            if _debugger_settings.save_layer_info:
+                with open(
+                    os.path.join(
+                        _debugger_settings.logging_dir, "engine_layer_info.json"
+                    ),
+                    "w",
+                ) as f:
+                    f.write(trt_module.get_layer_info())
+
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
 
@@ -927,7 +969,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1029,7 +1071,11 @@ def convert_exported_program_to_serialized_trt_engine(
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
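
Migration note: the deprecation messages above point callers at `torch_tensorrt.dynamo.Debugger` as the replacement for `debug=True`. Below is a minimal migration sketch, not taken from this PR; it assumes the `Debugger` context manager accepts a `log_level` argument plus the options suggested by the `DebuggerConfig` fields referenced in this diff (`logging_dir`, `save_engine_profile`, `profile_format`, `save_layer_info`), so the exact keyword names should be checked against the released API. `MyModule` is a placeholder.

    import torch
    import torch_tensorrt

    model = MyModule().eval().cuda()  # hypothetical module
    inputs = [torch.randn(1, 3, 224, 224).cuda()]

    # Before this change: debug logging was toggled through the compile() flag.
    # trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, debug=True)

    # After this change: wrap compilation in the Debugger context manager instead.
    # Keyword names below mirror the DebuggerConfig fields used in this diff and
    # are assumptions to verify against the installed torch_tensorrt version.
    with torch_tensorrt.dynamo.Debugger(
        log_level="debug",
        logging_dir="/tmp/torch_tensorrt_debug",
        save_engine_profile=True,
        profile_format="trex",
        save_layer_info=True,
    ):
        trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)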
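The new keyword-only `_debugger_settings` parameter on `compile_module` is not meant to be passed by callers directly; the `@fn_supports_debugger` decorator (defined in `_supports_debugger.py`, which is not part of this diff) presumably lets an active `Debugger` context inject its `DebuggerConfig`. The sketch below is purely illustrative of that decorator pattern under stated assumptions, not the repository's actual implementation.

    import functools
    from typing import Any, Callable, Optional, TypeVar

    F = TypeVar("F", bound=Callable[..., Any])

    # Module-level slot that a Debugger context manager would fill on __enter__
    # and clear on __exit__.  All names here are illustrative, not the real ones.
    _ACTIVE_DEBUGGER_CONFIG: Optional[Any] = None


    def fn_supports_debugger(fn: F) -> F:
        """Mark fn as debugger-aware and inject the active config as _debugger_settings."""

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if _ACTIVE_DEBUGGER_CONFIG is not None:
                kwargs.setdefault("_debugger_settings", _ACTIVE_DEBUGGER_CONFIG)
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]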