Commit d2e4f6d

perf: Add lowering passes to improve TRT conversion
- Focus on variance and sum converters, reducing instances of extraneous layers from unnecessary reshapes
- Add test cases to validate new additions
1 parent 5de208f commit d2e4f6d

14 files changed: +322 −23

docsrc/contributors/writing_dynamo_aten_lowering_passes.rst (+5 −5)

@@ -12,7 +12,7 @@ Lowering Pass Requirements
 ------------
 
 An ATen lowering pass function in Torch-TRT must satisfy two requirements:
-- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
+- The function must take as input a `torch.fx.GraphModule` and a sequence of torch Tensors, `Sequence[torch.Tensor]`, and return the lowered `torch.fx.GraphModule`
 - The function must leave the graph in a valid and invoke-able state, including performing any necessary linting and recompilation
 
 See this link for information on `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.
@@ -22,7 +22,7 @@ Example Lowering Pass
 
 .. code-block:: python
 
-    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def repair_input_as_output(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
         """Repair scenarios where inputs are also outputs of the graph
 
         TRT does not allow such cases, so we insert a clone (identity) layer
@@ -82,15 +82,15 @@ For instance, to insert the pass at the default location (end of the list), the
 .. code-block:: python
 
     @_aten_lowering_pass
-    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
         ...
 
 Alternatively, to insert the pass at a custom index (such as the front of the list) in the passlist, the following code can be used:
 
 .. code-block:: python
 
     @_aten_lowering_pass(index=0)
-    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
         ...
 
 There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.
@@ -101,7 +101,7 @@ There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for
     print(dump_lowering_passes())
 
     # Apply lowering passes to a GraphModule
-    apply_lowering_passes(graph_module)
+    apply_lowering_passes(graph_module, sample_inputs)
 
     # Remove the lowering pass at index 1
     _remove_lowering_pass(index=1)
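
As a quick end-to-end illustration of the new two-argument signature, the following is a minimal sketch that registers a trivial pass and applies the registry to a toy module. It assumes `_aten_lowering_pass` and `apply_lowering_passes` are importable from `torch_tensorrt.dynamo.lowering.passes` as the utilities above suggest; `shape_logging_pass` is a hypothetical pass name, not part of this commit.

    from typing import Sequence

    import torch

    from torch_tensorrt.dynamo.lowering.passes import (
        _aten_lowering_pass,
        apply_lowering_passes,
    )


    @_aten_lowering_pass
    def shape_logging_pass(
        gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
    ) -> torch.fx.GraphModule:
        # Sample inputs are now threaded through every pass, enabling shape-aware rewrites
        print([tuple(t.shape) for t in sample_inputs])
        return gm


    gm = torch.fx.symbolic_trace(torch.nn.ReLU())
    lowered_gm = apply_lowering_passes(gm, [torch.randn(2, 4)])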

py/torch_tensorrt/dynamo/aten_tracer.py (+1 −1)

@@ -28,6 +28,6 @@ def trace(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
         graph_module = export(model, tuple(inputs)).module()
-        graph_module = apply_lowering_passes(graph_module)
+        graph_module = apply_lowering_passes(graph_module, inputs)
     logger.debug("Post export graph: " + str(graph_module.graph))
     return graph_module

py/torch_tensorrt/dynamo/backend/backends.py (+1 −1)

@@ -87,7 +87,7 @@ def _pretraced_backend(
 
     logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-    gm = apply_lowering_passes(gm)
+    gm = apply_lowering_passes(gm, sample_inputs)
 
     trt_compiled = compile_module(
         gm,

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py (+3 −0)

@@ -149,6 +149,7 @@
     aten.special_log_ndtr,
     aten.special_xlog1py,
     aten.stack,
+    aten.std,
     aten.t,
     aten.tanh_backward,
     aten.threshold,
@@ -163,6 +164,8 @@
     aten.upsample_bilinear2d,
     aten.upsample_bilinear2d.vec,
     aten.upsample_nearest2d_backward,
+    aten.var,
+    aten.var_mean,
     aten.xlogy,
     aten.zero,
     aten.zero_,

py/torch_tensorrt/dynamo/lowering/_decompositions.py (+49 −1)

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch._decomp import register_decomposition
@@ -135,6 +135,54 @@ def reciprocal_replacement(
     return torch.div(1, input_)
 
 
+@register_torch_trt_decomposition(
+    torch.ops.prims.var.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def var_decomposition(
+    input_tensor: torch.Tensor,
+    dims: Optional[List[int]],
+    correction: int,
+    output_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if dims is None:
+        dims = []
+
+    # If the dimensions are empty, variance is taken over all dimensions
+    if isinstance(dims, (tuple, list)) and len(dims) == 0:
+        N = input_tensor.numel()
+    # Otherwise, the number of samples is the product of the dimensions reduced over
+    else:
+        N = 1
+        for dim_i in dims:
+            N *= input_tensor.shape[dim_i]
+
+    # Compute the mean, difference, and correction term as per the formula:
+    # https://pytorch.org/docs/stable/generated/torch.var.html
+
+    # Additionally, prims does not support keepdim, and so we only keep dimensions
+    # on the first reduction, then remove it for the second
+    sample_mean = torch.mean(input_tensor, dims, keepdim=True)
+    diff = input_tensor - sample_mean
+    squared_diff = diff * diff
+    variance_unnormalized = torch.sum(squared_diff, dims, keepdim=False)
+
+    if correction is None:
+        correction_term = float(N - 1)
+    elif isinstance(correction, int):
+        correction_term = float(N - correction)
+    elif isinstance(correction, float):
+        correction_term = float(N) - correction
+    else:
+        raise RuntimeError("correction must be int or float")
+
+    if correction_term <= 0:
+        raise RuntimeError(f"correction term was non-positive, got: {correction_term}")
+
+    variance = variance_unnormalized / correction_term
+
+    return variance
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
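
The decomposition mirrors the documented formula Var(x) = sum((x − mean(x))²) / (N − correction). As a sanity check, a minimal eager-mode sketch comparing it against `torch.var`, assuming `var_decomposition` remains directly callable after decoration and is importable from this module:

    import torch

    from torch_tensorrt.dynamo.lowering._decompositions import var_decomposition

    x = torch.randn(4, 8)

    # Reduce over dim 1 with Bessel's correction (correction=1), the torch.var default
    decomposed = var_decomposition(x, dims=[1], correction=1)
    reference = torch.var(x, dim=1, correction=1, keepdim=False)

    assert torch.allclose(decomposed, reference, atol=1e-6)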

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py (+13 −5)

@@ -1,9 +1,10 @@
 import logging
-from typing import Callable, Optional
+from typing import Callable, Optional, Sequence, Union
 
 import torch
 
 from .constant_folding import constant_fold
+from .fuse_prims_broadcast import fuse_prims_broadcast
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -13,19 +14,24 @@
         remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
+        fuse_prims_broadcast,
     ]
 )
 
 logger = logging.getLogger(__name__)
 
 
-LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
+LoweringPassSignature = Callable[
+    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
+]
 
 
 def _aten_lowering_pass(
     *args: LoweringPassSignature,
     index: Optional[int] = None,
-) -> LoweringPassSignature:
+) -> Union[
+    LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
+]:
     """Adds a lowering pass to the registry, at a specified index if desired
 
     If no index is specified, the lowering pass is inserted at the end of the list
@@ -65,12 +71,14 @@ def _remove_lowering_pass(*, index: int) -> None:
     return
 
 
-def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def apply_lowering_passes(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
     """Applies the lowering passes to a graph module, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
     )
-    return ATEN_LOWERING_PASSES(gm)
+    return ATEN_LOWERING_PASSES(gm, sample_inputs)
 
 
 def dump_lowering_passes() -> str:

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py (+4 −1)

@@ -1,4 +1,5 @@
 import logging
+from typing import Sequence
 
 import torch
 from torch_tensorrt._utils import sanitized_torch_version
@@ -21,7 +22,9 @@
 
 
 @torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def constant_fold(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
     """Adapted from:
     https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py (new file, +82 −0)

@@ -0,0 +1,82 @@
+import logging
+from typing import Sequence
+
+import torch
+from torch.fx.passes.shape_prop import ShapeProp
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Add relevant prims to this fusion
+def fuse_prims_broadcast(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True"""
+    modified_graph = False
+
+    # Propagate shapes through the graph to determine if broadcast can be resolved
+    try:
+        ShapeProp(gm).propagate(*sample_inputs)
+    except (RuntimeError, AssertionError):
+        logger.warning(
+            "Shape Propagation Failed on Graph, skipping fuse_prims_broadcast lowering pass",
+            exc_info=True,
+        )
+        return gm
+
+    for node in gm.graph.nodes:
+        # If the node is a sum prims operator, with broadcast_in_dim being the only consumer,
+        # it is a candidate for fusing
+        if (
+            node.target in (torch.ops.prims.sum.default,)
+            and len(node.users) == 1
+            and list(node.users)[0].target == torch.ops.prims.broadcast_in_dim.default
+        ):
+            # Get broadcasted shape, reduced dimensions, and original tensor shape
+            broadcast_node = list(node.users)[0]
+            broadcasted_shape = broadcast_node.args[1]
+            reduced_dims = node.args[1]
+            original_shape = node.args[0].meta["tensor_meta"].shape
+
+            # If the rank of the broadcasted shape is the same as the original,
+            # the broadcasts are all singletons for the reduced dimensions,
+            # and all of the non-reduced dimensions are identical to the originals,
+            # then the broadcast is effectively performing a "keep_dim=True" operation
+            if (
+                len(broadcasted_shape) == len(original_shape)
+                and all(broadcasted_shape[i] == 1 for i in reduced_dims)
+                and all(
+                    broadcasted_shape[j] == original_shape[j]
+                    for j in range(len(original_shape))
+                    if j not in reduced_dims
+                )
+            ):
+                # Fuse the operator to its convertible alternative
+                with gm.graph.inserting_after(broadcast_node):
+                    modified_graph = True
+
+                    if node.target == torch.ops.prims.sum.default:
+                        fused_node = gm.graph.call_function(
+                            torch.ops.aten.sum.dim_IntList,
+                            args=(node.args[0], reduced_dims, True),
+                        )
+
+                # Replace all uses of the broadcast node with the fused node
+                broadcast_node.replace_all_uses_with(
+                    fused_node,
+                )
+
+                # Erase the broadcast node and the original reduction node
+                gm.graph.erase_node(broadcast_node)
+                gm.graph.erase_node(node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Fused prims-broadcast paradigm:\n{gm.graph}")
+
+    return gm
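
The rewrite relies on a simple equivalence: a `prims.sum` whose sole consumer is a `broadcast_in_dim` restoring singleton dimensions computes the same result as a single ATen sum with `keepdim=True`, eliminating the extra reshape/broadcast layer in TRT. A minimal sketch of that equivalence, assuming these prims ops are callable in eager mode:

    import torch

    x = torch.randn(2, 3, 4)
    reduced_dims = [1]

    # prims form: reduce (dropping dim 1), then broadcast back to rank 3 with a singleton
    reduced = torch.ops.prims.sum.default(x, reduced_dims)  # shape (2, 4)
    prims_style = torch.ops.prims.broadcast_in_dim.default(
        reduced, [2, 1, 4], [0, 2]  # shape (2, 1, 4)
    )

    # fused ATen form emitted by the pass above
    aten_style = torch.ops.aten.sum.dim_IntList(x, reduced_dims, True)

    assert torch.allclose(prims_style, aten_style)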
py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py (+22 −6)

@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Sequence
 
 import torch
 from torch.fx.passes.pass_manager import PassManager
@@ -8,22 +8,34 @@ class DynamoPassManager(PassManager):  # type: ignore[misc]
     def __init__(
         self,
         passes: Optional[
-            List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]
+            List[
+                Callable[
+                    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
+                ]
+            ]
         ] = None,
     ):
         super().__init__(passes)
 
     @classmethod
     def build_from_passlist(
         cls,
-        passes: Optional[List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]],
+        passes: Optional[
+            List[
+                Callable[
+                    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
+                ]
+            ]
+        ],
     ) -> Any:
         pm = DynamoPassManager(passes)
         return pm
 
     def add_pass_with_index(
         self,
-        lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule],
+        lowering_pass: Callable[
+            [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
+        ],
         index: Optional[int] = None,
     ) -> None:
         if index is None:
@@ -35,8 +47,12 @@ def add_pass_with_index(
     def remove_pass_with_index(self, index: int) -> None:
         del self.passes[index]
 
-    def __call__(self, source: Any) -> Any:
-        return super().__call__(source)
+    def __call__(self, gm: Any, sample_inputs: Any) -> Any:
+        self.validate()
+        out, example_inputs = gm, sample_inputs
+        for _pass in self.passes:
+            out = _pass(out, example_inputs)
+        return out
 
     def __str__(self) -> str:
         return str(self.passes)
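
A minimal sketch of the updated calling convention, where `__call__` now takes both the module and its sample inputs and threads them through each registered pass. Here `noop_pass` is a hypothetical stand-in, not part of this commit:

    from typing import Sequence

    import torch

    from torch_tensorrt.dynamo.lowering.passes.pass_manager import DynamoPassManager


    def noop_pass(
        gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
    ) -> torch.fx.GraphModule:
        # A pass that makes no changes, illustrating the (gm, sample_inputs) signature
        return gm


    pm = DynamoPassManager.build_from_passlist([noop_pass])

    gm = torch.fx.symbolic_trace(torch.nn.ReLU())
    lowered_gm = pm(gm, [torch.randn(8)])  # invokes __call__(gm, sample_inputs)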

py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py (+4 −1)

@@ -1,4 +1,5 @@
 import logging
+from typing import Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -9,7 +10,9 @@
 
 
 # TODO: Delete this lowering pass once aot_export_joint_simple is patched
-def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def remove_input_alias_fixing_clones(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
     """Remove the auxiliary clone nodes inserted to fix input aliasing
 
     See: https://github.com/pytorch/pytorch/issues/108079

py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py (+4 −1)

@@ -1,4 +1,5 @@
 import logging
+from typing import Sequence
 
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
@@ -9,7 +10,9 @@
 logger = logging.getLogger(__name__)
 
 
-def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def repair_input_as_output(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
     """Repair scenarios where inputs are also outputs of the graph
 
     TRT does not allow such cases, so we insert a clone (identity) layer
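
For context, a minimal sketch of the disallowed pattern this pass repairs: a placeholder that is returned directly as a graph output. `InputAsOutput` is a hypothetical example module:

    import torch


    class InputAsOutput(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            # `x` is both a graph input and a graph output, which TRT engines reject;
            # the pass reroutes the returned `x` through a clone (identity) node
            return x, x + 1


    gm = torch.fx.symbolic_trace(InputAsOutput())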
