-from typing import Dict, Optional, Sequence
+import logging
+from typing import Dict, List, Optional, Sequence
 
 import torch
 
-from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch_tensorrt.dynamo.torch_compile._defaults import MIN_BLOCK_SIZE
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
+from torch.fx.graph_module import GraphModule
+from torch.fx.node import _get_qualified_name
 from torch.fx.passes.operator_support import OperatorSupport
 
 from torch_tensorrt.fx.converter_registry import CONVERTERS
 
 
+logger = logging.getLogger(__name__)
+
+
+class TRTPartitioner(CapabilityBasedPartitioner):
+    """Partitioner to split an FX graph into subgraphs based on operator support
+
+    Args:
+        graph_module: FX GraphModule to partition
+        operator_support: OperatorSupport class describing allowed operators
+        non_compute_ops: Operators which are not considered computational (e.g. getattr)
+        allowed_single_node_partition_ops: Nodes which can be included in single-node partitions.
+            Generally useful for module-level ops which are computationally intensive despite being single functions
+        min_block_size: Minimum number of computational operators per block
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def __init__(
+        self,
+        graph_module: GraphModule,
+        operator_support: OperatorSupport,
+        *,
+        non_compute_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        min_block_size=MIN_BLOCK_SIZE,
+    ) -> None:
+        super().__init__(
+            graph_module,
+            operator_support,
+            allows_single_node_partition=True,
+            non_compute_ops=non_compute_ops,
+            allowed_single_node_partition_ops=allowed_single_node_partition_ops,
+        )
+
+        self.min_block_size = min_block_size
+
+    def propose_partitions(self) -> List[Partition]:
+        # Propose partitions using the default, then refine the results
+        initial_proposed_partitions = super().propose_partitions()
+        partitions = {i: part for i, part in enumerate(initial_proposed_partitions)}
+
+        # For each partition, determine whether or not the number of computational operators
+        # exceeds the threshold, and if not, remove that partition
+        partitions_to_remove = {}
+        for id, partition in partitions.items():
+            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
+            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
+            exempted_partition = False
+
+            compute_node_count = 0
+            for node in partition.nodes:
+                # Partitions are exempted from min_block_size if they contain an allowed single-node op
+                if (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target)
+                    in self.allowed_single_node_partition_ops
+                ):
+                    exempted_partition = True
+                    break
+                elif (
+                    node.op == "call_function"
+                    and _get_qualified_name(node.target) not in non_compute_ops
+                ):
+                    compute_node_count += 1
+
+            if compute_node_count < self.min_block_size and not exempted_partition:
+                partitions_to_remove[id] = compute_node_count
+
+        # Remove any partitions violating the criteria specified by the user
+        for id, count in partitions_to_remove.items():
+            logger.debug(
+                f"Removing partition which has {count} < {self.min_block_size} computational operators"
+            )
+            del partitions[id]
+
+        return [partitions[k] for k in sorted(partitions.keys())]
+
+    def partition_and_fuse(self) -> GraphModule:
+        partitions = self.propose_partitions()
+        fused_gm = self.fuse_partitions(partitions)
+        return fused_gm
+
+
 class TorchTensorRTOperatorSupport(OperatorSupport):
     """Class to determine whether operators within a module are supported"""
 
-    def __init__(self, support_dict=None):
+    def __init__(self, support_dict=None, torch_executed_ops=set()):
         super().__init__(support_dict)
 
         # Initialize sets of supported/unsupported operators
         self.supported_operators = set()
         self.unsupported_operators = set()
+        self.torch_executed_ops = torch_executed_ops
 
     def is_node_supported(
         self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
-        if node.target in CONVERTERS.keys():
-            # If node is a proper computational node, store the operator
+        node_name = (
+            _get_qualified_name(node.target)
+            if not isinstance(node.target, str)
+            else node.target
+        )
+
+        if (
+            node.target in CONVERTERS.keys()
+            and node_name not in self.torch_executed_ops
+        ):
+            # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
-                node_name = node._pretty_print_target(node.target)
                 self.supported_operators.add(node_name)
 
             return True
         else:
             if not node.is_impure():
-                node_name = node._pretty_print_target(node.target)
                 self.unsupported_operators.add(node_name)
 
             return False
 
     def print_support_overview(self, num_trt_blocks: Optional[int] = None):
         if num_trt_blocks is not None:
-            print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}")
+            logger.debug(
+                f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
+            )
 
-        print("\nSupported Nodes:")
+        logger.debug("\nSupported Nodes:")
         for node_name in self.supported_operators:
-            print("-", node_name)
+            logger.debug(f"- {node_name}")
 
         if len(self.unsupported_operators) != 0:
-            print("\nUnsupported Nodes:")
+            logger.debug("\nUnsupported or Excluded Nodes:")
             for node_name in self.unsupported_operators:
-                print("-", node_name)
-            print("\n")
+                logger.debug(f"- {node_name}")
+            logger.debug("\n")
         else:
-            print("\nAll Nodes Supported\n")
+            logger.debug("\nAll Nodes Supported\n")
 
 
 def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = True,
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Sequence[str] = tuple(),
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
 
     Args:
         gm: FX GraphModule to partition
         verbose: Bool representing whether to print operator support
-        max_num_trt_engines: Maximum number of allowed TRT engines in partitioning
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
     Returns:
         torch.fx.GraphModule
     """
-    supported_ops = TorchTensorRTOperatorSupport()
-    partitioner = CapabilityBasedPartitioner(gm, supported_ops)
+    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
+    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
 
-    # Determine partitions, and raise error if the degree of partitioning
-    # exceeds a specified threshold
+    # Determine partitions based on user specifications and operator support
+    # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    num_blocks = len(partitions)
-    if num_blocks > max_num_trt_engines:
-        raise AssertionError(
-            f"The graph module has {num_blocks} TRT Engines which is larger than the "
-            + f"threshold={max_num_trt_engines}. Falling back to non-TRT module."
-        )
-
-    # Fuse partitions and display overview of supported/unsupported operators
     fused_graph = partitioner.fuse_partitions(partitions)
-    num_blocks = len(partitions)
 
     if verbose:
-        supported_ops.print_support_overview(num_blocks)
+        supported_ops.print_support_overview(len(partitions))
 
     return fused_graph
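For reference, a minimal usage sketch of the revised partition() entry point. This is not part of the diff: the import path, the example model, and the chosen settings are illustrative assumptions, and in the real pipeline the backend hands partition() an aten-level GraphModule rather than a symbolic trace.

import torch

# Assumed import path for the partition() utility shown in this diff; adjust to the
# module's actual location in the package
from torch_tensorrt.dynamo.torch_compile.lowering import partition


class MiniModel(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return torch.sigmoid(y) + y


# Illustration only: symbolic_trace yields torch-level ops, whereas the compile
# backend would normally supply an aten-level graph at this point
gm = torch.fx.symbolic_trace(MiniModel())

fused_gm = partition(
    gm,
    verbose=True,
    min_block_size=3,  # require at least 3 computational ops per accelerated block
    torch_executed_ops={"torch.ops.aten.add.Tensor"},  # always run aten.add in Torch
)
print(fused_gm.graph)

With these settings, any proposed partition holding fewer than three computational operators is dropped and left to run in Torch, unless it contains an op listed in allowed_single_node_partition_ops.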