1
1
import logging
2
2
import math
3
3
from dataclasses import dataclass , field
4
- from typing import List , Tuple
4
+ from typing import Any , Dict , List
5
5
6
- import torch
6
+ from torch_tensorrt . dynamo . _settings import CompilationSettings
7
7
8
8
logger = logging .getLogger (__name__ )
9
9
@@ -15,18 +15,18 @@ class PerSubgraphData:
15
15
Args:
16
16
subgraph_name (str): Name of the subgraph in the GraphModule
17
17
subgraph_op_count (int): Number of operations in the subgraph
18
- subgraph_input_shapes (List[Tuple[int, ...]] ): Shapes of input Tensors of the subgraph
19
- subgraph_input_dtypes (List[torch.device] ): Input data types of the subgraph
20
- subgraph_output_shapes (List[Tuple[int, ...]] ): Shapes of output Tensors of the subgraph
21
- subgraph_output_dtypes (List[torch.device] ): Output data types of the subgraph
18
+ subgraph_input_shapes (Any ): Shapes of input Tensors of the subgraph
19
+ subgraph_input_dtypes (Any ): Input data types of the subgraph
20
+ subgraph_output_shapes (Any ): Shapes of output Tensors of the subgraph
21
+ subgraph_output_dtypes (Any ): Output data types of the subgraph
22
22
"""
23
23
24
24
subgraph_name : str = ""
25
25
subgraph_op_count : int = 0
26
- subgraph_input_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
27
- subgraph_input_dtypes : List [ torch . device ] = field (default_factory = list )
28
- subgraph_output_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
29
- subgraph_output_dtypes : List [ torch . device ] = field (default_factory = list )
26
+ subgraph_input_shapes : Any = field (default_factory = list )
27
+ subgraph_input_dtypes : Any = field (default_factory = list )
28
+ subgraph_output_shapes : Any = field (default_factory = list )
29
+ subgraph_output_dtypes : Any = field (default_factory = list )
30
30
31
31
32
32
@dataclass
@@ -36,95 +36,86 @@ class DryRunTracker:
36
36
Args:
37
37
total_ops_in_graph (int): Total number of operators in graph
38
38
supported_ops_in_graph (int): Number of supported operators in graph
39
- graph_input_shapes (List[Tuple[int, ...]] ): Shapes of input Tensors of the graph
40
- graph_input_dtypes (List[torch.device] ): Input data types of the graph
41
- graph_output_shapes (List[Tuple[int, ...]] ): Shapes of output Tensors of the graph
42
- graph_output_dtypes (List[torch.device] ): Output data types of the graph
39
+ graph_input_shapes (Any ): Shapes of input Tensors of the graph
40
+ graph_input_dtypes (Any ): Input data types of the graph
41
+ graph_output_shapes (Any ): Shapes of output Tensors of the graph
42
+ graph_output_dtypes (Any ): Output data types of the graph
43
43
per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
44
44
tensorrt_graph_count (int): Number of TensorRT engines to be generated
45
- truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
45
+ compilation_settings (CompilationSettings): User Compilation Settings
46
+ unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
46
47
"""
47
48
48
49
total_ops_in_graph : int = 0
49
50
supported_ops_in_graph : int = 0
50
- graph_input_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
51
- graph_input_dtypes : List [ torch . device ] = field (default_factory = list )
52
- graph_output_shapes : List [ Tuple [ int , ...]] = field (default_factory = list )
53
- graph_output_dtypes : List [ torch . device ] = field (default_factory = list )
51
+ graph_input_shapes : Any = field (default_factory = list )
52
+ graph_input_dtypes : Any = field (default_factory = list )
53
+ graph_output_shapes : Any = field (default_factory = list )
54
+ graph_output_dtypes : Any = field (default_factory = list )
54
55
per_subgraph_data : List [PerSubgraphData ] = field (default_factory = list )
55
56
tensorrt_graph_count : int = 0
56
- truncated_long_and_double : bool = False
57
+ compilation_settings : CompilationSettings = field (
58
+ default_factory = CompilationSettings
59
+ )
60
+ unsupported_ops : Dict [str , int ] = field (default_factory = dict )
57
61
58
62
59
63
def dryrun_stats_display (dryrun_tracker : DryRunTracker , dryrun_enabled : bool ) -> None :
60
- """Displays statistics about the dryrun either to debug logs or info logs"""
61
- # If user specified "dryrun=True", print to info logs, else debug
62
- if dryrun_enabled :
63
- dryrun_logger = logger .info
64
- else :
65
- dryrun_logger = logger .debug
66
-
64
+ """Displays statistics about the dryrun either to debug logs or stdout"""
67
65
formatted_stats = "\n "
68
66
69
67
# Print overall stats about the graph, operator counts, etc.
70
- formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n "
68
+ formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n \n "
71
69
formatted_stats += (
72
70
f"The graph consists of { dryrun_tracker .total_ops_in_graph } Total Operators, "
73
71
f"of which { dryrun_tracker .supported_ops_in_graph } operators are supported, "
74
- f"{ round (dryrun_tracker .supported_ops_in_graph * 100 / dryrun_tracker .total_ops_in_graph , 2 )} % coverage\n "
75
- )
76
- formatted_stats += f"Long and double inputs were { '' if dryrun_tracker .truncated_long_and_double else 'not' } truncated (truncate_long_and_double={ dryrun_tracker .truncated_long_and_double } )\n "
77
- formatted_stats += (
78
- f"{ dryrun_tracker .tensorrt_graph_count } TRT Engine(s) were generated\n "
72
+ f"{ round (dryrun_tracker .supported_ops_in_graph * 100 / dryrun_tracker .total_ops_in_graph , 2 )} % coverage\n \n "
79
73
)
74
+ formatted_stats += f"The following ops are currently unsupported and set to run in Torch: { dryrun_tracker .unsupported_ops } \n \n "
75
+ formatted_stats += f"Compiled with: { dryrun_tracker .compilation_settings } \n \n "
80
76
81
77
assert len (dryrun_tracker .per_subgraph_data ) == dryrun_tracker .tensorrt_graph_count
82
78
83
79
# Print schematic of the graph structure, as in:
84
80
#
85
- # Inputs: [Tensor: (1, 3, 224, 224)@float32]
81
+ # Inputs: List [Tensor: (1, 3, 224, 224)@float32]
86
82
# ...
87
- # TRT Engine #1: _run_on_acc_0
88
- # Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
89
- # Number of Operators in Engine: 1
90
- # Engine Outputs: [ Tensor: (1, 64, 112, 112)@float32]
83
+ # TRT Engine #1 - Submodule name : _run_on_acc_0
84
+ # Engine Inputs: List [Tensor: (1, 3, 224, 224)@float32]
85
+ # Number of Operators in Engine: 1
86
+ # Engine Outputs: Tensor: (1, 64, 112, 112)@float32
91
87
# ...
92
- # Outputs: [Tensor: (1, 1000)@float32]
88
+ # Outputs: List [Tensor: (1, 1000)@float32]
93
89
#
94
90
formatted_stats += " " * 2 + "Graph Structure:\n \n "
95
91
formatted_stats += (
96
92
" " * 3
97
- + f"Inputs: [ { input_formatter (dryrun_tracker .graph_input_shapes , dryrun_tracker .graph_input_dtypes )} ] \n "
93
+ + f"Inputs: { input_formatter (dryrun_tracker .graph_input_shapes , dryrun_tracker .graph_input_dtypes )} \n "
98
94
)
99
95
100
96
for i , trt_subgraph_data in enumerate (dryrun_tracker .per_subgraph_data ):
101
- assert len (trt_subgraph_data .subgraph_input_dtypes ) == len (
102
- trt_subgraph_data .subgraph_input_shapes
103
- )
104
- assert len (trt_subgraph_data .subgraph_output_dtypes ) == len (
105
- trt_subgraph_data .subgraph_output_shapes
106
- )
107
97
formatted_stats += " " * 4 + "...\n "
108
98
formatted_stats += (
109
- " " * 4 + f"TRT Engine #{ i + 1 } : { trt_subgraph_data .subgraph_name } \n "
99
+ " " * 4
100
+ + f"TRT Engine #{ i + 1 } - Submodule name: { trt_subgraph_data .subgraph_name } \n "
110
101
)
111
102
formatted_stats += (
112
103
" " * 5
113
- + f"Engine Inputs: [ { input_formatter (trt_subgraph_data .subgraph_input_shapes , trt_subgraph_data .subgraph_input_dtypes )} ] \n "
104
+ + f"Engine Inputs: { input_formatter (trt_subgraph_data .subgraph_input_shapes , trt_subgraph_data .subgraph_input_dtypes )} \n "
114
105
)
115
106
formatted_stats += (
116
107
" " * 5
117
108
+ f"Number of Operators in Engine: { trt_subgraph_data .subgraph_op_count } \n "
118
109
)
119
110
formatted_stats += (
120
111
" " * 5
121
- + f"Engine Outputs: [ { input_formatter (trt_subgraph_data .subgraph_output_shapes , trt_subgraph_data .subgraph_output_dtypes )} ] \n "
112
+ + f"Engine Outputs: { input_formatter (trt_subgraph_data .subgraph_output_shapes , trt_subgraph_data .subgraph_output_dtypes )} \n "
122
113
)
123
114
124
115
formatted_stats += " " * 4 + "...\n "
125
116
formatted_stats += (
126
117
" " * 3
127
- + f"Outputs: [ { input_formatter (dryrun_tracker .graph_output_shapes , dryrun_tracker .graph_output_dtypes )} ] \n "
118
+ + f"Outputs: { input_formatter (dryrun_tracker .graph_output_shapes , dryrun_tracker .graph_output_dtypes )} \n "
128
119
)
129
120
130
121
# Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -167,23 +158,23 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
167
158
+ " " * 3
168
159
+ "- For minimal graph segmentation, select min_block_size="
169
160
+ f"{ most_ops_in_an_engine } which would generate "
170
- + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= most_ops_in_an_engine ])} TRT engines "
161
+ + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= most_ops_in_an_engine ])} TRT engine(s) "
171
162
)
172
163
if math .ceil (avg_ops_per_engine ) != most_ops_in_an_engine :
173
164
formatted_stats += (
174
165
"\n "
175
166
+ " " * 3
176
167
+ "- For moderate graph segmentation, select min_block_size="
177
168
+ f"{ math .ceil (avg_ops_per_engine )} which would generate "
178
- + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= math .ceil (avg_ops_per_engine )])} TRT engines "
169
+ + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= math .ceil (avg_ops_per_engine )])} TRT engine(s) "
179
170
)
180
171
181
172
formatted_stats += (
182
173
"\n "
183
174
+ " " * 3
184
175
+ "- The current level of graph segmentation is equivalent to selecting min_block_size="
185
176
+ f"{ min_ops_in_an_engine } which generates "
186
- + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= min_ops_in_an_engine ])} TRT engines "
177
+ + f"{ len ([1 for trt_subgraph in dryrun_tracker .per_subgraph_data if trt_subgraph .subgraph_op_count >= min_ops_in_an_engine ])} TRT engine(s) "
187
178
)
188
179
else :
189
180
formatted_stats += (
@@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
192
183
+ "Aggregate stats not available since no TRT Engines were generated."
193
184
)
194
185
195
- dryrun_logger (formatted_stats )
186
+ # If user specified "dryrun=True", print to stdout, else debug
187
+ if dryrun_enabled :
188
+ print (formatted_stats )
189
+ else :
190
+ logger .debug (formatted_stats )
196
191
197
192
198
def input_formatter(shapes: Any, dtypes: Any) -> str:
    """Format shapes and dtypes of input Tensors into a readable string.

    Recursively walks ``shapes`` and ``dtypes``, which must mirror each
    other's structure: a tuple of ints (a single Tensor shape), a
    list/tuple of nested structures, or a dict of nested structures
    keyed identically.

    Args:
        shapes: Shape structure — tuple of ints, list/tuple, or dict of such
        dtypes: Dtype structure paralleling ``shapes``; leaves are presumably
            torch.dtype values (``str(dtype)[6:]`` strips a "torch." prefix)
            — confirm against callers

    Returns:
        Human-readable summary, e.g. "List[Tensor: (1, 3, 224, 224)@float32]"

    Raises:
        ValueError: If an unrecognized container type is encountered
    """

    def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
        """Recursive helper; every return carries a trailing ", " separator."""
        # Base case - a single Tensor shape: a (possibly empty) tuple of ints
        if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
            # str(torch.float32) == "torch.float32"; [6:] drops "torch."
            return f"Tensor: {shapes}@{str(dtypes)[6:]}, "

        # Sequence case - recurse over parallel shape/dtype entries
        elif isinstance(shapes, (list, tuple)):
            opener, closer = (
                ("List[", "], ") if isinstance(shapes, list) else ("Tuple(", "), ")
            )
            inner = "".join(
                input_formatter_helper(shape, dtype)
                for shape, dtype in zip(shapes, dtypes)
            )
            # Strip the trailing ", " from the *joined* entries only; slicing
            # "" is a no-op, so empty containers render as "List[]"/"Tuple()"
            # (the original sliced the opener itself, garbling "List[" -> "Lis")
            return opener + inner[:-2] + closer

        # Dictionary case - recurse over values, pairing dtypes by key
        elif isinstance(shapes, dict):
            inner = "".join(
                input_formatter_helper(shape, dtypes[key])
                for key, shape in shapes.items()
            )
            return "Dict{" + inner[:-2] + "}, "

        else:
            raise ValueError(
                f"Invalid input type {type(shapes)} encountered in parse_complex_tensor_structs parsing."
            )

    # Strip the outermost trailing ", " separator
    return input_formatter_helper(shapes, dtypes)[:-2]
0 commit comments