
Commit 100191b

feat: Add preliminary support for freezing tensors in Dynamo (#2128)
1 parent 7a4288e commit 100191b

File tree

19 files changed, +427 −132 lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+119 −30
@@ -1,12 +1,16 @@
 from __future__ import annotations
 
 import logging
-from functools import partial
-from typing import Any, Callable, Sequence
+import unittest
+from typing import Any, Callable, Dict, Optional, Sequence
 
 import torch
 import torch._dynamo as td
-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+import torch.utils._pytree as pytree
+from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import _aot_export_function
+from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
+from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -33,31 +37,15 @@ def torch_tensorrt_backend(
 
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
 
-    compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
-    return compiled_mod
+    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
 
 
 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
-
-    custom_backend = partial(
-        _pretraced_backend,
-        settings=settings,
-    )
-
-    # Perform Pre-AOT Lowering for Module-Level Replacement
-    gm = pre_aot_substitutions(gm)
-
-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module_simplified(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )
+    return _pretraced_backend(gm, sample_inputs, settings)
 
 
 def _pretraced_backend(
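
For context, these registered backends are reached through torch.compile. A minimal usage sketch follows; the toy model and the public backend name "torch_tensorrt" are illustrative assumptions, not part of this hunk.

# Hedged usage sketch: invoking the Dynamo backend via torch.compile.
# The model and the backend name are illustrative assumptions.
import torch
import torch_tensorrt  # noqa: F401  (side effect: registers the backend)

model = (
    torch.nn.Sequential(
        torch.nn.Linear(32, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 8),
    )
    .eval()
    .cuda()  # requires a CUDA-capable setup
)

compiled = torch.compile(model, backend="torch_tensorrt")
print(compiled(torch.randn(4, 32).cuda()).shape)  # torch.Size([4, 8])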
@@ -75,22 +63,44 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
-        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+        logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))
+
+        # Perform Pre-AOT Lowering for Module-Level Replacement
+        gm = pre_aot_substitutions(gm)
+
+        fake_mode = detect_fake_mode(sample_inputs)
+
+        # Place backend tracing within FakeTensor context allowing nonfake Tensors
+        with unittest.mock.patch.object(
+            fake_mode, "allow_non_fake_inputs", True
+        ), fake_mode:
+            # Invoke AOTAutograd to translate operators to aten
+            graph_module = aot_export_for_compile(
+                gm,
+                sample_inputs,
+                decompositions=get_decompositions(
+                    settings.enable_experimental_decompositions
+                ),
+            )
 
-        trt_compiled = compile_module(
-            gm,
-            sample_inputs,
-            settings=settings,
-        )
-        return trt_compiled
-    except AssertionError:
+            logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+
+            constant_fold(graph_module)
+
+            trt_compiled = compile_module(
+                graph_module,
+                sample_inputs,
+                settings=settings,
+            )
+            return trt_compiled
+    except (AssertionError, RuntimeError):
         if not settings.pass_through_build_failures:
             logger.warning(
                 "TRT conversion failed on the subgraph. See trace above. "
                 + "Returning GraphModule forward instead.",
                 exc_info=True,
             )
-            return gm.forward
+            return gm
         else:
             logger.critical(
                 "Halting compilation on build failure since "
@@ -100,3 +110,82 @@ def _pretraced_backend(
                 + "specify pass_through_build_failures=False."
             )
             raise
+
+
+@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
+def constant_fold(gm: torch.fx.GraphModule) -> Any:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
+
+    Folds constants in the graph module, not skipping constructors
+
+    Modifies the graph in-place and replaces node with constants
+    """
+    cf = ConstantFolder(gm, skip_constructors=False)
+    cf.run()
+
+    for node, constant in cf.node_replacements.items():
+        replace_node_with_constant(gm, node, constant)
+
+    erased_params = []
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" and len(node.users) == 0:
+            delattr(gm, node.target)
+            erased_params.append(node)
+
+    for node in erased_params:
+        gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+
+def aot_export_for_compile(
+    func: torch.fx.GraphModule,
+    args: Sequence[torch.Tensor],
+    *,
+    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
+) -> torch.fx.GraphModule:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
+
+    Removed check for input aliasing in resultant subgraph - TRT is functional-only
+
+    Exports the function to ATen for torch compile
+    """
+    # Trace function with input arguments and decompositions
+    with torch.no_grad():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            func,
+            args,
+            decompositions=decompositions,
+        )
+
+    # No input mutations
+    if (
+        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
+        != 0
+    ):
+        raise RuntimeError(
+            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
+        )
+    # No pytrees
+    if type(in_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
+        )
+    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
+        )
+    if type(out_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
+        )
+    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
+        )
+
+    return fx_g
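
To see what constant_fold accomplishes, consider a module whose computation partly depends only on a frozen buffer. A hedged sketch, assuming a toy module traced with torch.fx.symbolic_trace:

# Sketch (assumed toy module): subexpressions that depend only on frozen
# attributes are precomputed and baked into the graph as constants.
import torch
from torch_tensorrt.dynamo.backend.backends import constant_fold

class Frozen(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("w", torch.randn(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ (self.w + 1.0)  # (self.w + 1.0) is input-independent

gm = torch.fx.symbolic_trace(Frozen())
constant_fold(gm)  # the add is replaced by a single folded get_attr constant
print(gm.code)

Folding such subexpressions ahead of time is what lets the TensorRT interpreter treat them as plain weights rather than graph operations.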

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+29 −2
@@ -3,14 +3,15 @@
 from datetime import datetime
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
-import numpy
+import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
+from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
 from torch_tensorrt.fx.observer import Observer
@@ -169,7 +170,7 @@ def run(
 
         cache = None
         if timing_cache:
-            cache_file = numpy.array(timing_cache)
+            cache_file = np.array(timing_cache)
             cache = builder_config.create_timing_cache(cache_file.tobytes())
         else:
             cache = builder_config.create_timing_cache(b"")
@@ -323,6 +324,21 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         assert self._cur_node_name is not None
         return converter(self.network, target, args, kwargs, self._cur_node_name)
 
+    def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
+        with _disable_current_modes():
+            from torch_tensorrt.fx.converters import to_numpy
+
+            frozen_attr = self.fetch_attr(target)
+
+            if isinstance(frozen_attr, torch.nn.Parameter):
+                constant_tensor = frozen_attr.data
+            else:
+                constant_tensor = frozen_attr
+
+            network_constant = to_numpy(constant_tensor)
+
+            return network_constant
+
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
         converter = CONVERTERS.get(self._cur_node)
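
The get_attr override above hands frozen weights to converters as numpy arrays. Downstream, such an array can be embedded as a TensorRT constant layer; the following standalone sketch uses the public TensorRT Python API rather than repo helpers, so everything beyond the tensorrt package is illustrative.

# Illustrative sketch (not repo code): a numpy constant, such as the array
# returned by get_attr above, becomes an ITensor via a constant layer.
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

frozen = np.random.randn(8, 4).astype(np.float32)  # stand-in for a frozen weight
const_layer = network.add_constant(frozen.shape, trt.Weights(frozen))
itensor = const_layer.get_output(0)  # usable as input to downstream layers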
@@ -344,6 +360,17 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         else:
             outputs = (args[0],)
 
+        for output_idx in range(len(outputs)):
+            from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+
+            output = outputs[output_idx]
+
+            if not isinstance(output, trt.tensorrt.ITensor):
+                new_output = get_trt_tensor(self.network, output, target)
+                outputs = (
+                    outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
+                )
+
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")
 
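Since outputs is an immutable tuple, the loop above replaces a single element by re-concatenating the slices around it. A toy illustration of the idiom:

# Toy illustration of the tuple-splice idiom used in the output hook above.
outputs = ("a", "b", "c")
output_idx, new_output = 1, "B"
outputs = outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
assert outputs == ("a", "B", "c")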