
Commit 1feccb0

feat: Add dryrun feature to Dynamo paths
- Enables building of TRT engines with "dryrun" capabilities, meaning all phases except conversion are run and verbose logs of the graph structure and composition are printed for the user
- Improves general-purpose debug logging by printing dryrun stats to the debug logs regardless of option specification
- Provides an intuitive schematic of the graph engines, inputs, and code path through the course of the graph
1 parent 4985c70 commit 1feccb0
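For context, a minimal usage sketch of the new flag, assuming the standard torch_tensorrt.compile entry point with ir="dynamo"; the model and inputs below are hypothetical placeholders, not part of this commit:

import torch
import torch_tensorrt

# Hypothetical model and sample input, for illustration only
model = (
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 64, kernel_size=7, stride=2),
        torch.nn.ReLU(),
    )
    .eval()
    .cuda()
)
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# dryrun=True runs every compilation phase except TRT conversion and
# prints a schematic of the partitioned graph to the info logs
trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    min_block_size=1,  # recommended by this commit for the most thorough dryrun analysis
    dryrun=True,
)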

File tree

4 files changed: +294 -9 lines changed
py/torch_tensorrt/dynamo/_DryRunTracker.py

+205
@@ -0,0 +1,205 @@
import logging
import math
from dataclasses import dataclass, field
from typing import List, Tuple

import torch

logger = logging.getLogger(__name__)


@dataclass
class PerSubgraphData:
    """Class to track data on a per-subgraph level

    Args:
        subgraph_name (str): Name of the subgraph in the GraphModule
        subgraph_op_count (int): Number of operations in the subgraph
        subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
        subgraph_input_dtypes (List[torch.dtype]): Input data types of the subgraph
        subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
        subgraph_output_dtypes (List[torch.dtype]): Output data types of the subgraph
    """

    subgraph_name: str = ""
    subgraph_op_count: int = 0
    subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    subgraph_input_dtypes: List[torch.dtype] = field(default_factory=list)
    subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    subgraph_output_dtypes: List[torch.dtype] = field(default_factory=list)


@dataclass
class DryRunTracker:
    """Class to track data on a graph-wide level

    Args:
        total_ops_in_graph (int): Total number of operators in the graph
        supported_ops_in_graph (int): Number of supported operators in the graph
        graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
        graph_input_dtypes (List[torch.dtype]): Input data types of the graph
        graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
        graph_output_dtypes (List[torch.dtype]): Output data types of the graph
        per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
        tensorrt_graph_count (int): Number of TensorRT engines to be generated
        truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
    """

    total_ops_in_graph: int = 0
    supported_ops_in_graph: int = 0
    graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    graph_input_dtypes: List[torch.dtype] = field(default_factory=list)
    graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    graph_output_dtypes: List[torch.dtype] = field(default_factory=list)
    per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
    tensorrt_graph_count: int = 0
    truncated_long_and_double: bool = False


def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
    """Displays statistics about the dryrun either to debug logs or info logs"""
    # If the user specified dryrun=True, print to info logs, else debug
    if dryrun_enabled:
        dryrun_logger = logger.info
    else:
        dryrun_logger = logger.debug

    formatted_stats = "\n"

    # Print overall stats about the graph, operator counts, etc.
    formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n"
    formatted_stats += (
        f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, "
        f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
        f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n"
    )
    formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not '}truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n"
    formatted_stats += (
        f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n"
    )

    assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count

    # Print a schematic of the graph structure, as in:
    #
    #   Inputs: [Tensor: (1, 3, 224, 224)@float32]
    #    ...
    #    TRT Engine #1: _run_on_acc_0
    #     Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
    #     Number of Operators in Engine: 1
    #     Engine Outputs: [Tensor: (1, 64, 112, 112)@float32]
    #    ...
    #   Outputs: [Tensor: (1, 1000)@float32]
    #
    formatted_stats += " " * 2 + "Graph Structure:\n\n"
    formatted_stats += (
        " " * 3
        + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n"
    )

    for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
        assert len(trt_subgraph_data.subgraph_input_dtypes) == len(
            trt_subgraph_data.subgraph_input_shapes
        )
        assert len(trt_subgraph_data.subgraph_output_dtypes) == len(
            trt_subgraph_data.subgraph_output_shapes
        )
        formatted_stats += " " * 4 + "...\n"
        formatted_stats += (
            " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n"
        )
        formatted_stats += (
            " " * 5
            + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n"
        )
        formatted_stats += (
            " " * 5
            + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
        )
        formatted_stats += (
            " " * 5
            + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n"
        )

    formatted_stats += " " * 4 + "...\n"
    formatted_stats += (
        " " * 3
        + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n"
    )

    # Print aggregate statistics about the graph structure, including recommended "min_block_size" options
    if dryrun_tracker.tensorrt_graph_count > 0:
        min_ops_in_an_engine = min(
            trt_subgraph.subgraph_op_count
            for trt_subgraph in dryrun_tracker.per_subgraph_data
        )
        avg_ops_per_engine = (
            sum(
                trt_subgraph.subgraph_op_count
                for trt_subgraph in dryrun_tracker.per_subgraph_data
            )
            / dryrun_tracker.tensorrt_graph_count
        )
        avg_ops_per_engine = round(avg_ops_per_engine, 2)
        most_ops_in_an_engine = max(
            trt_subgraph.subgraph_op_count
            for trt_subgraph in dryrun_tracker.per_subgraph_data
        )

        formatted_stats += "\n" + " " * 2 + "-" * 25 + " Aggregate Stats " + "-" * 25
        formatted_stats += (
            "\n\n"
            + " " * 3
            + "Average Number of Operators per TRT Engine: "
            + f"{avg_ops_per_engine}"
        )

        formatted_stats += (
            "\n"
            + " " * 3
            + "Most Operators in a TRT Engine: "
            + f"{most_ops_in_an_engine}"
        )

        formatted_stats += "\n\n" + " " * 2 + "*" * 10 + " Recommendations " + "*" * 10
        formatted_stats += (
            "\n\n"
            + " " * 3
            + "- For minimal graph segmentation, select min_block_size="
            + f"{most_ops_in_an_engine} which would generate "
            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines"
        )
        if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine:
            formatted_stats += (
                "\n"
                + " " * 3
                + "- For moderate graph segmentation, select min_block_size="
                + f"{math.ceil(avg_ops_per_engine)} which would generate "
                + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines"
            )

        formatted_stats += (
            "\n"
            + " " * 3
            + "- The current level of graph segmentation is equivalent to selecting min_block_size="
            + f"{min_ops_in_an_engine} which generates "
            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines"
        )
    else:
        formatted_stats += (
            "\n"
            + " " * 2
            + "Aggregate stats not available since no TRT Engines were generated."
        )

    dryrun_logger(formatted_stats)


def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str:
    """Format shapes and dtypes of input Tensors into a readable string"""
    formatted_str = ", "

    for shape, dtype in zip(shapes, dtypes):
        # str(torch.float32) is "torch.float32"; slice off the "torch." prefix
        formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, "

    # Strip the leading ", " sentinel and the trailing ", "
    return formatted_str[2:-2]
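To make the log format concrete, here is a small hypothetical driver for the classes above; the op counts and shapes are invented for illustration:

import logging
import torch

logging.basicConfig(level=logging.INFO)

# One hypothetical TRT engine containing 5 operators
subgraph = PerSubgraphData(
    subgraph_name="_run_on_acc_0",
    subgraph_op_count=5,
    subgraph_input_shapes=[(1, 3, 224, 224)],
    subgraph_input_dtypes=[torch.float32],
    subgraph_output_shapes=[(1, 1000)],
    subgraph_output_dtypes=[torch.float32],
)
tracker = DryRunTracker(
    total_ops_in_graph=8,
    supported_ops_in_graph=5,
    graph_input_shapes=[(1, 3, 224, 224)],
    graph_input_dtypes=[torch.float32],
    graph_output_shapes=[(1, 1000)],
    graph_output_dtypes=[torch.float32],
    per_subgraph_data=[subgraph],
    tensorrt_graph_count=1,
)

# Prints the schematic to the info logs, e.g.
#   Inputs: [Tensor: (1, 3, 224, 224)@float32]
dryrun_stats_display(tracker, dryrun_enabled=True)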

py/torch_tensorrt/dynamo/_compiler.py

+84 -9
@@ -5,7 +5,6 @@
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union

 import torch
-import torch_tensorrt
 from torch.export import ExportedProgram
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
@@ -16,6 +15,7 @@
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
     DEVICE,
+    DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
@@ -29,6 +29,11 @@
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
+from torch_tensorrt.dynamo._DryRunTracker import (
+    DryRunTracker,
+    PerSubgraphData,
+    dryrun_stats_display,
+)
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     convert_module,
@@ -43,6 +48,8 @@
     to_torch_tensorrt_device,
 )

+import torch_tensorrt
+
 logger = logging.getLogger(__name__)

@@ -75,6 +82,7 @@ def compile(
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    dryrun: bool = DRYRUN,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -131,6 +139,7 @@ def compile(
         use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
         use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the number of graphs run in TensorRT.
+        dryrun (bool): Toggle for "Dryrun" mode, which runs everything except conversion to TRT and logs the outputs
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -192,6 +201,7 @@ def compile(
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "require_full_compilation": require_full_compilation,
+        "dryrun": dryrun,
     }

     settings = CompilationSettings(**compilation_options)
@@ -215,15 +225,32 @@ def compile_module(
     Returns:
         Compiled FX GraphModule
     """
+    dryrun_tracker = DryRunTracker()

     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
         gm, settings.debug, settings.torch_executed_ops
     )

+    dryrun_tracker.total_ops_in_graph = total_ops
+    dryrun_tracker.supported_ops_in_graph = num_supported_ops
+    dryrun_tracker.graph_input_shapes = [
+        tuple(input_.shape) for input_ in sample_inputs
+    ]
+    dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
+    dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
+
+    if settings.dryrun and settings.min_block_size > 1:
+        logger.info(
+            "It is recommended to run `dryrun` mode with `min_block_size=1`, "
+            "for the most thorough analysis"
+        )
+
     # If the number of supported operations is 0 or less than the block size, skip the subgraph
     # TODO: Add condition to second expression below when require_full_compilation is added
-    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
+    if num_supported_ops == 0 or (
+        num_supported_ops < settings.min_block_size and not settings.dryrun
+    ):
         logger.warning(
             f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
             f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
@@ -274,6 +301,16 @@ def compile_module(
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             continue

+        subgraph_data = PerSubgraphData()
+        subgraph_data.subgraph_name = name
+        subgraph_data.subgraph_op_count = len(
+            [
+                node
+                for node in submodule.graph.nodes
+                if node.op in ("call_function", "call_method", "call_module")
+            ]
+        )
+
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
         submodule_inputs = partitioning.get_submod_inputs(
             partitioned_module,
@@ -300,15 +337,51 @@ def compile_module(
             name,
         )

-        # Create TRT engines from submodule
-        trt_module = convert_module(
-            submodule,
-            submodule_inputs,
-            settings=settings,
-            name=name,
+        subgraph_data.subgraph_input_dtypes = [
+            submodule_input.torch_dtype for submodule_input in submodule_inputs
+        ]
+        subgraph_data.subgraph_input_shapes = [
+            tuple(submodule_input.shape) for submodule_input in submodule_inputs
+        ]
+
+        submodule_outputs = submodule(
+            *get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
         )
+        if not isinstance(submodule_outputs, (list, tuple)):
+            submodule_outputs = [submodule_outputs]

-        trt_modules[name] = trt_module
+        subgraph_data.subgraph_output_dtypes = [
+            submodule_output.dtype for submodule_output in submodule_outputs
+        ]
+        subgraph_data.subgraph_output_shapes = [
+            tuple(submodule_output.shape) for submodule_output in submodule_outputs
+        ]
+
+        dryrun_tracker.tensorrt_graph_count += 1
+        dryrun_tracker.per_subgraph_data.append(subgraph_data)
+
+        # Create TRT engines from submodule
+        if not settings.dryrun:
+            trt_module = convert_module(
+                submodule,
+                submodule_inputs,
+                settings=settings,
+                name=name,
+            )
+
+            trt_modules[name] = trt_module
+
+    sample_outputs = gm(
+        *get_torch_inputs(sample_inputs, to_torch_device(settings.device))
+    )
+
+    if not isinstance(sample_outputs, (list, tuple)):
+        sample_outputs = [sample_outputs]
+
+    dryrun_tracker.graph_output_shapes = [
+        tuple(output_.shape) for output_ in sample_outputs
+    ]
+    dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs]

     # Replace all FX Modules with TRT Modules
     for name, trt_module in trt_modules.items():
@@ -318,4 +391,6 @@ def compile_module(
     if fast_partitioner_failed:
         settings.use_fast_partitioner = True

+    dryrun_stats_display(dryrun_tracker, settings.dryrun)
+
     return partitioned_module
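For intuition on the min_block_size recommendations that dryrun_stats_display derives from the per-engine operator counts, here is a worked example; the operator counts are invented:

import math

# Hypothetical operator counts for three candidate TRT engines
op_counts = [2, 5, 9]

most = max(op_counts)                                   # 9 -> minimal segmentation
moderate = math.ceil(sum(op_counts) / len(op_counts))   # ceil(5.33) = 6 -> moderate
least = min(op_counts)                                  # 2 -> current segmentation

def engines_at(threshold: int) -> int:
    """Number of engines whose operator count meets a min_block_size threshold"""
    return sum(1 for count in op_counts if count >= threshold)

print(engines_at(most))      # 1 engine  at min_block_size=9
print(engines_at(moderate))  # 1 engine  at min_block_size=6
print(engines_at(least))     # 3 engines at min_block_size=2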

py/torch_tensorrt/dynamo/_defaults.py

+1
@@ -15,6 +15,7 @@
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
 REQUIRE_FULL_COMPILATION = False
+DRYRUN = False


 def default_device() -> Device:
