@@ -25,10 +25,6 @@ def __init__(
25
25
logging_dir : Optional [str ] = None ,
26
26
):
27
27
self .debug_file_dir = tempfile .TemporaryDirectory ().name
28
- if log_level != "graphs" and (capture_fx_graph_after or save_engine_profile ):
29
- _LOGGER .warning (
30
- "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
31
- )
32
28
33
29
if log_level == "debug" :
34
30
self .log_level = logging .DEBUG
@@ -60,7 +56,7 @@ def __enter__(self) -> None:
60
56
self .rt_level = torch .ops .tensorrt .get_logging_level ()
61
57
dictConfig (self .get_config ())
62
58
63
- if self .log_level == GRAPH_LEVEL :
59
+ if self .capture_fx_graph_before or self . capture_fx_graph_after :
64
60
self .old_pre_passes , self .old_post_passes = (
65
61
ATEN_PRE_LOWERING_PASSES .passes ,
66
62
ATEN_POST_LOWERING_PASSES .passes ,
@@ -93,14 +89,14 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
93
89
94
90
dictConfig (self .get_default_config ())
95
91
torch .ops .tensorrt .set_logging_level (self .rt_level )
96
- if self .log_level == GRAPH_LEVEL and self .capture_fx_graph_after :
92
+ if self .capture_fx_graph_before or self .capture_fx_graph_after :
97
93
ATEN_PRE_LOWERING_PASSES .passes , ATEN_POST_LOWERING_PASSES .passes = (
98
94
self .old_pre_passes ,
99
95
self .old_post_passes ,
100
96
)
101
97
self .debug_file_dir = tempfile .TemporaryDirectory ().name
102
98
103
- def get_config (self ) -> dict [str , Any ]:
99
+ def get_customized_logging_config (self ) -> dict [str , Any ]:
104
100
config = {
105
101
"version" : 1 ,
106
102
"disable_existing_loggers" : False ,
@@ -138,7 +134,7 @@ def get_config(self) -> dict[str, Any]:
138
134
}
139
135
return config
140
136
141
- def get_default_config (self ) -> dict [str , Any ]:
137
+ def get_default_logging_config (self ) -> dict [str , Any ]:
142
138
config = {
143
139
"version" : 1 ,
144
140
"disable_existing_loggers" : False ,
0 commit comments