@@ -1,12 +1,16 @@
 from __future__ import annotations
 
 import logging
-from functools import partial
-from typing import Any, Callable, Sequence
+import unittest
+from typing import Any, Callable, Dict, Optional, Sequence
 
 import torch
 import torch._dynamo as td
-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+import torch.utils._pytree as pytree
+from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import _aot_export_function
+from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
+from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -33,31 +37,15 @@ def torch_tensorrt_backend(
 
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
 
-    compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
-    return compiled_mod
+    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
 
 
 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
-
-    custom_backend = partial(
-        _pretraced_backend,
-        settings=settings,
-    )
-
-    # Perform Pre-AOT Lowering for Module-Level Replacement
-    gm = pre_aot_substitutions(gm)
-
-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module_simplified(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )
+    return _pretraced_backend(gm, sample_inputs, settings)
 
 
 def _pretraced_backend(
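For orientation (not part of the diff): after this change, `aot_torch_tensorrt_aten_backend` only parses the Dynamo kwargs and hands the graph straight to `_pretraced_backend`, which now owns pre-AOT lowering, ATen export, and constant folding. A minimal usage sketch, assuming Torch-TensorRT with TensorRT and a CUDA device are available and that importing `torch_tensorrt` registers the backend named above:

    # Illustrative only -- not part of this commit.
    import torch
    import torch_tensorrt  # assumed to register the Dynamo backends on import

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval().cuda()
    inputs = (torch.randn(4, 8, device="cuda"),)

    # torch.compile resolves the backend by the name registered via td.register_backend
    compiled = torch.compile(model, backend="aot_torch_tensorrt_aten")
    compiled(*inputs)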
@@ -75,22 +63,44 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
-        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+        logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))
+
+        # Perform Pre-AOT Lowering for Module-Level Replacement
+        gm = pre_aot_substitutions(gm)
+
+        fake_mode = detect_fake_mode(sample_inputs)
+
+        # Place backend tracing within FakeTensor context allowing nonfake Tensors
+        with unittest.mock.patch.object(
+            fake_mode, "allow_non_fake_inputs", True
+        ), fake_mode:
+            # Invoke AOTAutograd to translate operators to aten
+            graph_module = aot_export_for_compile(
+                gm,
+                sample_inputs,
+                decompositions=get_decompositions(
+                    settings.enable_experimental_decompositions
+                ),
+            )
 
-        trt_compiled = compile_module(
-            gm,
-            sample_inputs,
-            settings=settings,
-        )
-        return trt_compiled
-    except AssertionError:
+        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+
+        constant_fold(graph_module)
+
+        trt_compiled = compile_module(
+            graph_module,
+            sample_inputs,
+            settings=settings,
+        )
+        return trt_compiled
+    except (AssertionError, RuntimeError):
         if not settings.pass_through_build_failures:
             logger.warning(
                 "TRT conversion failed on the subgraph. See trace above. "
                 + "Returning GraphModule forward instead.",
                 exc_info=True,
             )
-            return gm.forward
+            return gm
         else:
             logger.critical(
                 "Halting compilation on build failure since "
@@ -100,3 +110,82 @@ def _pretraced_backend(
                 + "specify pass_through_build_failures=False."
             )
             raise
+
+
+@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
+def constant_fold(gm: torch.fx.GraphModule) -> Any:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
+
+    Folds constants in the graph module, not skipping constructors
+
+    Modifies the graph in-place and replaces node with constants
+    """
+    cf = ConstantFolder(gm, skip_constructors=False)
+    cf.run()
+
+    for node, constant in cf.node_replacements.items():
+        replace_node_with_constant(gm, node, constant)
+
+    erased_params = []
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" and len(node.users) == 0:
+            delattr(gm, node.target)
+            erased_params.append(node)
+
+    for node in erased_params:
+        gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+
+def aot_export_for_compile(
+    func: torch.fx.GraphModule,
+    args: Sequence[torch.Tensor],
+    *,
+    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
+) -> torch.fx.GraphModule:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
+
+    Removed check for input aliasing in resultant subgraph - TRT is functional-only
+
+    Exports the function to ATen for torch compile
+    """
+    # Trace function with input arguments and decompositions
+    with torch.no_grad():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            func,
+            args,
+            decompositions=decompositions,
+        )
+
+    # No input mutations
+    if (
+        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
+        != 0
+    ):
+        raise RuntimeError(
+            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
+        )
+    # No pytrees
+    if type(in_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
+        )
+    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
+        )
+    if type(out_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
+        )
+    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
+        )
+
+    return fx_g
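As a rough sketch of what the new `constant_fold` helper does (again illustrative, not part of the commit; it assumes Inductor's `ConstantFolder` handles a plain `torch.fx.symbolic_trace` graph the same way it handles the exported ATen graph), a subexpression that depends only on module attributes gets baked into a single frozen constant and the now-unused parameter is dropped:

    # Toy example: fold a parameter-only subexpression with the helper above.
    import torch

    class AddScaledWeight(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # self.weight * 2 does not depend on x, so it is a folding candidate
            return x + self.weight * 2

    gm = torch.fx.symbolic_trace(AddScaledWeight())
    constant_fold(gm)   # mul(weight, 2) should become a baked-in constant attribute
    print(gm.graph)     # the original `weight` get_attr should no longer appear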