
Commit 0d85466

fix: Improve partitioning + lowering systems
- Improve torch.compile Dynamo partitioning system by incorporating key arguments including `min_block_size` and `torch_executed_ops`, which are available for use in TorchScript
- Improve torch.compile lowering system by adding key new decompositions to improve coverage and reduce the number of unique operators requiring implementation
- Update testing framework to use utilities, reducing code replication
- Add extensive testing of new partitioning system and lowering phases
1 parent 7433d4b commit 0d85466


9 files changed: +390 −71 lines changed


py/torch_tensorrt/dynamo/torch_compile/__init__.py

+10 −6

@@ -4,7 +4,7 @@
 import torch_tensorrt
 from functools import partial

-from typing import Any
+from typing import Any, Sequence
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision

@@ -15,7 +15,7 @@
     PRECISION,
     DEBUG,
     MAX_WORKSPACE_SIZE,
-    MAX_NUM_TRT_ENGINES,
+    MIN_BLOCK_SIZE,
 )


@@ -41,7 +41,7 @@ def compile(
     calibrator=None,
     truncate_long_and_double=False,
     require_full_compilation=False,
-    min_block_size=3,
+    min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
     **kwargs,
@@ -50,7 +50,7 @@ def compile(
     logger.warn(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}"
+        + "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -80,6 +80,8 @@
         precision=lower_precision,
         debug=debug,
         workspace_size=workspace_size,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
         **kwargs,
     )

@@ -100,7 +102,8 @@ def create_backend(
     precision: LowerPrecision = PRECISION,
     debug: bool = DEBUG,
     workspace_size: int = MAX_WORKSPACE_SIZE,
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[str] = set(),
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -117,7 +120,8 @@ def create_backend(
         debug=debug,
         precision=precision,
         workspace_size=workspace_size,
-        max_num_trt_engines=max_num_trt_engines,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
     )

     return partial(

py/torch_tensorrt/dynamo/torch_compile/_defaults.py

+1 −1

@@ -4,4 +4,4 @@
 PRECISION = LowerPrecision.FP32
 DEBUG = False
 MAX_WORKSPACE_SIZE = 20 << 30
-MAX_NUM_TRT_ENGINES = 200
+MIN_BLOCK_SIZE = 3
py/torch_tensorrt/dynamo/torch_compile/_settings.py

+5 −3

@@ -1,11 +1,12 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import Sequence

 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.dynamo.torch_compile._defaults import (
     PRECISION,
     DEBUG,
     MAX_WORKSPACE_SIZE,
-    MAX_NUM_TRT_ENGINES,
+    MIN_BLOCK_SIZE,
 )


@@ -14,4 +15,5 @@ class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
     workspace_size: int = MAX_WORKSPACE_SIZE
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES
+    min_block_size: int = MIN_BLOCK_SIZE
+    torch_executed_ops: Sequence[str] = field(default_factory=set)
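The `field(default_factory=set)` default is used because dataclasses reject bare mutable defaults, and a per-instance factory also avoids shared state between settings objects. A standalone sketch of the pattern (the class name is illustrative, not from this commit):

from dataclasses import dataclass, field
from typing import Sequence

@dataclass
class ExampleSettings:
    # A bare `torch_executed_ops: Sequence[str] = set()` default would raise
    # "mutable default ... use default_factory" at class creation time;
    # default_factory builds a fresh set for each instance instead.
    torch_executed_ops: Sequence[str] = field(default_factory=set)

a, b = ExampleSettings(), ExampleSettings()
assert a.torch_executed_ops == set()
assert a.torch_executed_ops is not b.torch_executed_ops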

py/torch_tensorrt/dynamo/torch_compile/backends.py

+4 −1

@@ -90,7 +90,10 @@ def compile_module(
     """
     # Partition module into components that can be TRT-accelerated
     partitioned_module = partition(
-        gm, verbose=settings.debug, max_num_trt_engines=settings.max_num_trt_engines
+        gm,
+        verbose=settings.debug,
+        min_block_size=settings.min_block_size,
+        torch_executed_ops=settings.torch_executed_ops,
     )

     # Iterate over all components that can be accelerated

py/torch_tensorrt/dynamo/torch_compile/lowering/_decompositions.py

+15

@@ -41,5 +41,20 @@ def inplace_op(*args, **kwargs):
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)


+@register_decomposition(aten.std, registry=DECOMPOSITIONS)
+def std_replacement(*args, **kwargs) -> torch.Tensor:
+    return torch.sqrt(torch.var(*args, **kwargs))
+
+
+@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)
+def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
+    return torch.reciprocal(torch.sqrt(*args, **kwargs))
+
+
+@register_decomposition(aten.alias, registry=DECOMPOSITIONS)
+def alias_replacement(x: torch.Tensor) -> torch.Tensor:
+    return x
+
+
 def get_decompositions():
     return DECOMPOSITIONS
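These decompositions rewrite `aten.std`, `aten.rsqrt`, and `aten.alias` in terms of operators that already have converters. A quick eager-mode sanity check of the substitutions (independent of the registry, for illustration only):

import torch

x = torch.randn(4, 8)

# std == sqrt(var): both use the same default correction, so the rewrite is exact.
assert torch.allclose(torch.std(x), torch.sqrt(torch.var(x)))

# rsqrt == reciprocal(sqrt), checked on strictly positive inputs.
y = x.abs() + 1e-3
assert torch.allclose(torch.rsqrt(y), torch.reciprocal(torch.sqrt(y)))

# alias is an identity view, so returning the input unchanged is equivalent.
assert torch.equal(torch.ops.aten.alias(x), x)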

py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py

+120 −31

@@ -1,92 +1,181 @@
-from typing import Dict, Optional, Sequence
+import logging
+from typing import Dict, List, Optional, Sequence

 import torch

-from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch_tensorrt.dynamo.torch_compile._defaults import MIN_BLOCK_SIZE
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import _get_qualified_name
 from torch.fx.passes.operator_support import OperatorSupport

 from torch_tensorrt.fx.converter_registry import CONVERTERS


+logger = logging.getLogger(__name__)
+
+
+class TRTPartitioner(CapabilityBasedPartitioner):
+    """Partitioner to split an FX graph into subgraphs based on operator support
+
+    Args:
+        graph_module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        non_compute_ops: Operators which are not considered computational (e.g. getattr)
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+            Generally useful for module-level exclusion ops which are intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        graph_module: GraphModule,
+        operator_support: OperatorSupport,
+        *,
+        non_compute_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        min_block_size=MIN_BLOCK_SIZE,
+    ) -> None:
+        super().__init__(
+            graph_module,
+            operator_support,
+            allows_single_node_partition=True,
+            non_compute_ops=non_compute_ops,
+            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
+        )
+
+        self.min_block_size = min_block_size
+
+    def propose_partitions(self) -> List[Partition]:
+        # Propose partitions using the default, then refine the results
+        initial_proposed_partitions = super().propose_partitions()
+        partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
+
+        # For each partition, determine whether or not the number of computational operators
+        # exceeds the threshold, and if not, remove that partition
+        partitions_to_remove = {}
+        for id, partition in partitions.items():
+            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
+            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
+            exempted_partition = False
+
+            compute_node_count = 0
+            for node in partition.nodes:
+                # Partitions are exempted from min_block_size if they contain an allowed single-node op
+                if (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target)
+                    in self.allowed_single_node_partition_ops
+                ):
+                    exempted_partition = True
+                    break
+                elif (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target) not in non_compute_ops
+                ):
+                    compute_node_count += 1
+
+            if compute_node_count < self.min_block_size and not exempted_partition:
+                partitions_to_remove[id] = compute_node_count
+
+        # Remove any nodes violating the criteria specified by the user
+        for id, count in partitions_to_remove.items():
+            logger.debug(
+                f"Removing partition which has {count} < {self.min_block_size} computational operators"
+            )
+            del partitions[id]
+
+        return [partitions[k] for k in sorted(partitions.keys())]
+
+    def partition_and_fuse(self) -> GraphModule:
+        partitions = self.propose_partitions()
+        fused_gm = self.fuse_partitions(partitions)
+        return fused_gm
+
+
 class TorchTensorRTOperatorSupport(OperatorSupport):
     """Class to determine whether operators within a module are supported"""

-    def __init__(self, support_dict=None):
+    def __init__(self, support_dict=None, torch_executed_ops=set()):
         super().__init__(support_dict)

         # Initialize sets of supported/unsupported operators
         self.supported_operators = set()
         self.unsupported_operators = set()
+        self.torch_executed_ops = torch_executed_ops

     def is_node_supported(
         self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
-        if node.target in CONVERTERS.keys():
-            # If node is a proper computational node, store the operator
+        node_name = (
+            _get_qualified_name(node.target)
+            if not isinstance(node.target, str)
+            else node.target
+        )
+
+        if (
+            node.target in CONVERTERS.keys()
+            and node_name not in self.torch_executed_ops
+        ):
+            # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
-                node_name = node._pretty_print_target(node.target)
                 self.supported_operators.add(node_name)

             return True
         else:
             if not node.is_impure():
-                node_name = node._pretty_print_target(node.target)
                 self.unsupported_operators.add(node_name)

             return False

     def print_support_overview(self, num_trt_blocks: Optional[int] = None):
         if num_trt_blocks is not None:
-            print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}")
+            logger.debug(
+                f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
+            )

-        print("\nSupported Nodes:")
+        logger.debug("\nSupported Nodes:")
         for node_name in self.supported_operators:
-            print("-", node_name)
+            logger.debug("-", node_name)

         if len(self.unsupported_operators) != 0:
-            print("\nUnsupported Nodes:")
+            logger.debug("\nUnsupported or Excluded Nodes:")
             for node_name in self.unsupported_operators:
-                print("-", node_name)
-            print("\n")
+                logger.debug("-", node_name)
+            logger.debug("\n")
         else:
-            print("\nAll Nodes Supported\n")
+            logger.debug("\nAll Nodes Supported\n")


 def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[str] = set(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support

     Args:
         gm: FX GraphModule to partition
         verbose: Bool representing whether to print operator support
-        max_num_trt_engines: Maximum number of allowed TRT engines in partitioning
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
     Returns:
         torch.fx.GraphModule
     """
-    supported_ops = TorchTensorRTOperatorSupport()
-    partitioner = CapabilityBasedPartitioner(gm, supported_ops)
+    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
+    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)

-    # Determine partitions, and raise error if the degree of partitioning
-    # exceeds a specified threshold
+    # Determine partitions based on user specifications and operator support
+    # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    num_blocks = len(partitions)
-    if num_blocks > max_num_trt_engines:
-        raise AssertionError(
-            f"The graph module has {num_blocks} TRT Engines which is larger than the "
-            + f"threshold={max_num_trt_engines}. Falling back to non-TRT module."
-        )
-
-    # Fuse partitions and display overview of supported/unsupported operators
     fused_graph = partitioner.fuse_partitions(partitions)
-    num_blocks = len(partitions)

     if verbose:
-        supported_ops.print_support_overview(num_blocks)
+        supported_ops.print_support_overview(len(partitions))

     return fused_graph
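To illustrate the `min_block_size` criterion, the sketch below counts "computational" call_function nodes the same way `TRTPartitioner.propose_partitions` does, but on a toy graph traced with `symbolic_trace`; the module and threshold are illustrative, and real partitions come from aten-level graphs produced by the Dynamo backend rather than from symbolic tracing.

import torch
from torch.fx import symbolic_trace
from torch.fx.node import _get_qualified_name

# Same defaults treated as non-computational inside propose_partitions
NON_COMPUTE_OPS = {"torch.ops.aten.view", "_operator.getitem"}


class TinyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


gm = symbolic_trace(TinyModule())
compute_node_count = sum(
    1
    for node in gm.graph.nodes
    if node.op == "call_function"
    and _get_qualified_name(node.target) not in NON_COMPUTE_OPS
)

# Two computational nodes (add, relu): a partition like this would be dropped
# under the default MIN_BLOCK_SIZE of 3 unless it contains an exempted
# single-node op.
print(compute_node_count)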
