Commit ffbef2c

fix: Address PR comments and update logging scheme
- Fix test case failures
1 parent 1feccb0 commit ffbef2c

8 files changed: +177 -112 lines changed
+81-59
@@ -1,9 +1,9 @@
 import logging
 import math
 from dataclasses import dataclass, field
-from typing import List, Tuple
+from typing import Any, Dict, List
 
-import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 
 logger = logging.getLogger(__name__)
 
@@ -15,18 +15,18 @@ class PerSubgraphData:
     Args:
         subgraph_name (str): Name of the subgraph in the GraphModule
         subgraph_op_count (int): Number of operations in the subgraph
-        subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
-        subgraph_input_dtypes (List[torch.device]): Input data types of the subgraph
-        subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
-        subgraph_output_dtypes (List[torch.device]): Output data types of the subgraph
+        subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
+        subgraph_input_dtypes (Any): Input data types of the subgraph
+        subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
+        subgraph_output_dtypes (Any): Output data types of the subgraph
     """
 
     subgraph_name: str = ""
     subgraph_op_count: int = 0
-    subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
-    subgraph_input_dtypes: List[torch.device] = field(default_factory=list)
-    subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
-    subgraph_output_dtypes: List[torch.device] = field(default_factory=list)
+    subgraph_input_shapes: Any = field(default_factory=list)
+    subgraph_input_dtypes: Any = field(default_factory=list)
+    subgraph_output_shapes: Any = field(default_factory=list)
+    subgraph_output_dtypes: Any = field(default_factory=list)
 
 
 @dataclass
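
The shape/dtype fields above are loosened from concrete List types to Any so a subgraph can record arbitrarily nested input structures (lists, tuples, dicts), not just flat lists. A minimal sketch of what the tracker can now hold; the import path is assumed, since this excerpt omits the first file's name:

import torch
from torch_tensorrt.dynamo._DryRunTracker import PerSubgraphData  # path assumed

data = PerSubgraphData(
    subgraph_name="_run_on_acc_0",
    subgraph_op_count=1,
    # Nested dict-of-tuples structures are now representable directly,
    # where the old List[Tuple[int, ...]] typing allowed only flat lists:
    subgraph_input_shapes={"x": (1, 3, 224, 224)},
    subgraph_input_dtypes={"x": torch.float32},
    subgraph_output_shapes=[(1, 64, 112, 112)],
    subgraph_output_dtypes=[torch.float32],
)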
@@ -36,95 +36,86 @@ class DryRunTracker:
     Args:
         total_ops_in_graph (int): Total number of operators in graph
         supported_ops_in_graph (int): Number of supported operators in graph
-        graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
-        graph_input_dtypes (List[torch.device]): Input data types of the graph
-        graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
-        graph_output_dtypes (List[torch.device]): Output data types of the graph
+        graph_input_shapes (Any): Shapes of input Tensors of the graph
+        graph_input_dtypes (Any): Input data types of the graph
+        graph_output_shapes (Any): Shapes of output Tensors of the graph
+        graph_output_dtypes (Any): Output data types of the graph
         per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
         tensorrt_graph_count (int): Number of TensorRT engines to be generated
-        truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
+        compilation_settings (CompilationSettings): User Compilation Settings
+        unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
     """
 
     total_ops_in_graph: int = 0
     supported_ops_in_graph: int = 0
-    graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
-    graph_input_dtypes: List[torch.device] = field(default_factory=list)
-    graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
-    graph_output_dtypes: List[torch.device] = field(default_factory=list)
+    graph_input_shapes: Any = field(default_factory=list)
+    graph_input_dtypes: Any = field(default_factory=list)
+    graph_output_shapes: Any = field(default_factory=list)
+    graph_output_dtypes: Any = field(default_factory=list)
     per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
     tensorrt_graph_count: int = 0
-    truncated_long_and_double: bool = False
+    compilation_settings: CompilationSettings = field(
+        default_factory=CompilationSettings
+    )
+    unsupported_ops: Dict[str, int] = field(default_factory=dict)
 
 
 def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
-    """Displays statistics about the dryrun either to debug logs or info logs"""
-    # If user specified "dryrun=True", print to info logs, else debug
-    if dryrun_enabled:
-        dryrun_logger = logger.info
-    else:
-        dryrun_logger = logger.debug
-
+    """Displays statistics about the dryrun either to debug logs or stdout"""
     formatted_stats = "\n"
 
     # Print overall stats about the graph, operator counts, etc.
-    formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n"
+    formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n\n"
     formatted_stats += (
         f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, "
         f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
-        f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n"
-    )
-    formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n"
-    formatted_stats += (
-        f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n"
+        f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n"
     )
+    formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n"
+    formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n"
 
     assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count
 
     # Print schematic of the graph structure, as in:
     #
-    # Inputs: [Tensor: (1, 3, 224, 224)@float32]
+    # Inputs: List[Tensor: (1, 3, 224, 224)@float32]
     # ...
-    # TRT Engine #1: _run_on_acc_0
-    #  Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
-    #  Number of Operators in Engine: 1
-    #  Engine Outputs: [Tensor: (1, 64, 112, 112)@float32]
+    # TRT Engine #1 - Submodule name: _run_on_acc_0
+    #  Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32]
+    #  Number of Operators in Engine: 1
+    #  Engine Outputs: Tensor: (1, 64, 112, 112)@float32
     # ...
-    # Outputs: [Tensor: (1, 1000)@float32]
+    # Outputs: List[Tensor: (1, 1000)@float32]
     #
     formatted_stats += " " * 2 + "Graph Structure:\n\n"
     formatted_stats += (
         " " * 3
-        + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n"
+        + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n"
     )
 
     for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
-        assert len(trt_subgraph_data.subgraph_input_dtypes) == len(
-            trt_subgraph_data.subgraph_input_shapes
-        )
-        assert len(trt_subgraph_data.subgraph_output_dtypes) == len(
-            trt_subgraph_data.subgraph_output_shapes
-        )
         formatted_stats += " " * 4 + "...\n"
         formatted_stats += (
-            " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n"
+            " " * 4
+            + f"TRT Engine #{i+1} - Submodule name: {trt_subgraph_data.subgraph_name}\n"
         )
         formatted_stats += (
             " " * 5
-            + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n"
+            + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n"
        )
         formatted_stats += (
             " " * 5
             + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
         )
         formatted_stats += (
             " " * 5
-            + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n"
+            + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n"
         )
 
     formatted_stats += " " * 4 + "...\n"
     formatted_stats += (
         " " * 3
-        + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n"
+        + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n"
     )
 
     # Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -167,23 +158,23 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
             + " " * 3
             + "- For minimal graph segmentation, select min_block_size="
             + f"{most_ops_in_an_engine} which would generate "
-            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines"
+            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engine(s)"
         )
         if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine:
             formatted_stats += (
                 "\n"
                 + " " * 3
                 + "- For moderate graph segmentation, select min_block_size="
                 + f"{math.ceil(avg_ops_per_engine)} which would generate "
-                + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines"
+                + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engine(s)"
             )
 
         formatted_stats += (
             "\n"
             + " " * 3
             + "- The current level of graph segmentation is equivalent to selecting min_block_size="
             + f"{min_ops_in_an_engine} which generates "
-            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines"
+            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engine(s)"
         )
     else:
         formatted_stats += (
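
To make the three recommendations concrete, here is a worked example of the arithmetic with invented per-engine op counts (none of these numbers come from this commit):

import math

op_counts = [1, 4, 10]  # hypothetical ops per TRT subgraph
most_ops = max(op_counts)                               # 10 -> "minimal segmentation"
moderate = math.ceil(sum(op_counts) / len(op_counts))   # ceil(5.0) = 5 -> "moderate"
min_ops = min(op_counts)                                # 1 -> equivalent to current behavior

print(sum(1 for c in op_counts if c >= most_ops))   # 1 engine
print(sum(1 for c in op_counts if c >= moderate))   # 1 engine
print(sum(1 for c in op_counts if c >= min_ops))    # 3 engines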
@@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
             + "Aggregate stats not available since no TRT Engines were generated."
         )
 
-    dryrun_logger(formatted_stats)
+    # If user specified "dryrun=True", print to stdout, else debug
+    if dryrun_enabled:
+        print(formatted_stats)
+    else:
+        logger.debug(formatted_stats)
 
 
-def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str:
+def input_formatter(shapes: Any, dtypes: Any) -> str:
     """Format shapes and dtypes of input Tensors into a readable string"""
-    formatted_str = ", "
 
-    for shape, dtype in zip(shapes, dtypes):
-        formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, "
+    def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
+        """Helper for input formatter"""
+        # Base case - single shape, single dtype
+        if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
+            return f"Tensor: {shapes}@{str(dtypes)[6:]}, "
+
+        # Shapes is a sequence
+        elif isinstance(shapes, (list, tuple)):
+            formatted_str = "List[" if isinstance(shapes, list) else "Tuple("
+            for shape, dtype in zip(shapes, dtypes):
+                formatted_str += input_formatter_helper(shape, dtype)
+            formatted_str = formatted_str[:-2] + (
+                "], " if isinstance(shapes, list) else "), "
+            )
+            return formatted_str
+
+        # Shapes is a dictionary
+        elif isinstance(shapes, dict):
+            formatted_str = "Dict{"
+
+            for key, shape in shapes.items():
+                formatted_str += input_formatter_helper(shape, dtypes[key])
+
+            formatted_str = formatted_str[:-2] + "}, "
+            return formatted_str
+
+        else:
+            raise ValueError(
+                f"Invalid input type {type(shapes)} encountered in parse_complex_tensor_structs parsing."
+            )
 
-    return formatted_str[2:-2]
+    return input_formatter_helper(shapes, dtypes)[:-2]
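
For reference, a short demonstration of how the recursive formatter renders flat and nested structures; the import path is assumed, since this excerpt omits the file's name, and the expected outputs are traced from the helper above:

import torch
from torch_tensorrt.dynamo._DryRunTracker import input_formatter  # path assumed

print(input_formatter([(1, 3, 224, 224)], [torch.float32]))
# List[Tensor: (1, 3, 224, 224)@float32]
print(input_formatter({"x": (8, 16)}, {"x": torch.float16}))
# Dict{Tensor: (8, 16)@float16}
print(input_formatter(((2, 2), (3, 3)), (torch.int32, torch.int64)))
# Tuple(Tensor: (2, 2)@int32, Tensor: (3, 3)@int64)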

py/torch_tensorrt/dynamo/_compiler.py

+30-25
@@ -42,6 +42,7 @@
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes
 from torch_tensorrt.dynamo.utils import (
     get_torch_inputs,
+    parse_complex_tensor_structs,
     prepare_inputs,
     set_log_level,
     to_torch_device,
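
parse_complex_tensor_structs lives in torch_tensorrt.dynamo.utils, one of the 8 changed files not shown in this excerpt. Judging only from its call sites below (a possibly nested input structure, an attribute name, and an optional post-processing callable such as tuple), it plausibly behaves like this hedged sketch:

from typing import Any, Callable, Optional

def parse_complex_tensor_structs_sketch(
    inputs: Any,
    attribute_to_extract: str,
    apply_fn: Optional[Callable[[Any], Any]] = None,
) -> Any:
    # Assumed behavior: mirror list/tuple/dict nesting, extracting one
    # attribute per leaf and optionally post-processing it.
    if isinstance(inputs, (list, tuple)):
        return type(inputs)(
            parse_complex_tensor_structs_sketch(i, attribute_to_extract, apply_fn)
            for i in inputs
        )
    elif isinstance(inputs, dict):
        return {
            k: parse_complex_tensor_structs_sketch(v, attribute_to_extract, apply_fn)
            for k, v in inputs.items()
        }
    else:
        value = getattr(inputs, attribute_to_extract)
        return apply_fn(value) if apply_fn is not None else value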
@@ -234,11 +235,13 @@ def compile_module(
 
     dryrun_tracker.total_ops_in_graph = total_ops
     dryrun_tracker.supported_ops_in_graph = num_supported_ops
-    dryrun_tracker.graph_input_shapes = [
-        tuple(input_.shape) for input_ in sample_inputs
-    ]
-    dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
-    dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
+    dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
+        sample_inputs, "shape", tuple
+    )
+    dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
+        sample_inputs, "torch_dtype"
+    )
+    dryrun_tracker.compilation_settings = settings
 
     if settings.dryrun and settings.min_block_size > 1:
         logger.info(
@@ -267,7 +270,7 @@ def compile_module(
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
-            partitioned_module = partitioning.fast_partition(
+            partitioned_module, supported_ops = partitioning.fast_partition(
                 gm,
                 verbose=settings.debug,
                 min_block_size=settings.min_block_size,
@@ -284,13 +287,15 @@ def compile_module(
             settings.use_fast_partitioner = False
 
     if not settings.use_fast_partitioner:
-        partitioned_module = partitioning.global_partition(
+        partitioned_module, supported_ops = partitioning.global_partition(
             gm,
             verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )
 
+    dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators
+
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
     # Iterate over all components that can be accelerated
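
Both partitioners now also return their OpSupportTester, whose unsupported_operators tally populates dryrun_tracker.unsupported_ops. Per the DryRunTracker docstring this is a Dict[str, int]; a purely hypothetical value (the key format is assumed, not taken from the commit):

# Hypothetical contents, mapping each unsupported op to its occurrence count:
unsupported_ops = {
    "torch.ops.aten.nonzero.default": 2,
    "torch.ops.aten.unique_consecutive.default": 1,
}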
@@ -337,25 +342,23 @@ def compile_module(
             name,
         )
 
-        subgraph_data.subgraph_input_dtypes = [
-            submodule_input.torch_dtype for submodule_input in submodule_inputs
-        ]
-        subgraph_data.subgraph_input_shapes = [
-            tuple(submodule_input.shape) for submodule_input in submodule_inputs
-        ]
+        subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs(
+            submodule_inputs, "shape", tuple
+        )
+        subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs(
+            submodule_inputs, "torch_dtype"
+        )
 
         submodule_outputs = submodule(
             *get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
         )
-        if not isinstance(submodule_outputs, (list, tuple)):
-            submodule_outputs = [submodule_outputs]
 
-        subgraph_data.subgraph_output_dtypes = [
-            submodule_output.dtype for submodule_output in submodule_outputs
-        ]
-        subgraph_data.subgraph_output_shapes = [
-            tuple(submodule_output.shape) for submodule_output in submodule_outputs
-        ]
+        subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs(
+            submodule_outputs, "shape", tuple
+        )
+        subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs(
+            submodule_outputs, "dtype"
+        )
 
         dryrun_tracker.tensorrt_graph_count += 1
         dryrun_tracker.per_subgraph_data.append(subgraph_data)
@@ -378,10 +381,12 @@ def compile_module(
     if not isinstance(sample_outputs, (list, tuple)):
         sample_outputs = [sample_outputs]
 
-    dryrun_tracker.graph_output_shapes = [
-        tuple(output_.shape) for output_ in sample_outputs
-    ]
-    dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs]
+    dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs(
+        sample_outputs, "shape", tuple
+    )
+    dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs(
+        sample_outputs, "dtype"
+    )
 
     # Replace all FX Modules with TRT Modules
     for name, trt_module in trt_modules.items():

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

+3-3
@@ -248,7 +248,7 @@ def partition(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
-) -> torch.fx.GraphModule:
+) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
 
@@ -259,7 +259,7 @@ def partition(
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
         require_full_compilation: Require that all computational operators be run in TRT
     Returns:
-        torch.fx.GraphModule
+        torch.fx.GraphModule, OpSupportTester
     """
     # Ensure graph is clean prior to partitioning
     gm.graph.eliminate_dead_code()
@@ -280,4 +280,4 @@ def partition(
     if verbose:
         supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
 
-    return partitioned_graph
+    return partitioned_graph, supported_ops
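
Since partition() now returns a pair, any external caller must unpack it, as the _compiler.py changes above already do. A hedged usage sketch (not from the commit; `gm` is assumed to be an aten-level torch.fx.GraphModule produced by the dynamo frontend):

import torch.fx
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition

def repartition(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Unpack the new (GraphModule, OpSupportTester) return pair.
    partitioned_graph, supported_ops = partition(gm, min_block_size=3)
    # Callers can now inspect the support tally directly:
    print(supported_ops.unsupported_operators)
    return partitioned_graph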
