
Commit a208df5

feat: Add preliminary support for freezing tensors in Dynamo
1 parent: 65277c5

File tree: 6 files changed, +249 -6 lines changed

6 files changed

+249
-6
lines changed
@@ -0,0 +1,127 @@ (new file)

import torch
import torch.utils._pytree as pytree
from torch import nn
from typing import Callable, Optional, Dict
from torch._functorch.aot_autograd import (
    AOT_COUNTER,
    create_functional_call,
    create_aot_dispatcher_function,
    AOTConfig,
)
from torch._subclasses import FakeTensor
from torch._functorch.partitioners import default_partition


def aot_module(
    mod: nn.Module,
    args,
    fw_compiler: Callable,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    keep_inference_input_mutations=False,
) -> nn.Module:
    """
    Adapted from:
    https://github.com/pytorch/pytorch/blob/cce2b7e3c95a7505b41bdfc53939d84d56e31260/torch/_functorch/aot_autograd.py#L3656-L3776

    This is the simplified or low overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove the
    (1) pytree overhead to parse inputs/outputs,
    (2) AOT Autograd cache,
    (3) Reading of params/buffers in every forward call

    :func:`aot_module_simplified` removes these overheads.
    """

    params = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = list(params_flat)
    params_len = len(params_flat)

    functional_call = create_functional_call(mod, params_spec, params_len)

    seen_sources = set()

    full_args = []
    # First, the params
    full_args.extend(params_flat)

    if torch._guards.TracingContext.get():
        torch._guards.TracingContext.get().params_flat = params_flat

    aot_autograd_arg_pos_to_source = None
    # Then, the params 1:1 mapped sources, if relevant.
    if hasattr(mod, "_param_name_to_source"):
        aot_autograd_arg_pos_to_source = []
        # We now know this came from dynamo, and (1) we care about guards,
        # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
        # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
        for name in params.keys():
            assert name in mod._param_name_to_source, f"{name} not found."
            source = mod._param_name_to_source[name]
            assert source not in seen_sources, source
            seen_sources.add(source)
            aot_autograd_arg_pos_to_source.append(source)

    # Next, the input args
    full_args.extend(args)

    if hasattr(mod, "graph"):
        # Non dynamo entrypoints can get to here...
        for i, node in enumerate(mod.graph.nodes):
            if node.op == "placeholder":
                if hasattr(node, "_dynamo_source"):
                    # ... but not here!
                    if aot_autograd_arg_pos_to_source is None:
                        aot_autograd_arg_pos_to_source = []
                    source = node._dynamo_source
                    assert source not in seen_sources, source
                    seen_sources.add(source)
                    aot_autograd_arg_pos_to_source.append(source)

    if aot_autograd_arg_pos_to_source is not None:
        assert len(full_args) == len(aot_autograd_arg_pos_to_source)

    dynamic_shapes = False
    for x in full_args:
        if isinstance(x, FakeTensor):
            dynamic_shapes = x.fake_mode.shape_env is not None
            break

    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=fw_compiler,
        inference_compiler=fw_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=params_len,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
        is_export=False,
        no_tangents=False,
    )

    compiled_fn = create_aot_dispatcher_function(
        functional_call,
        full_args,
        aot_config,
    )

    def forward(*runtime_args):
        full_args = []
        full_args.extend(runtime_args)
        return compiled_fn(full_args)

    # Just for convenience
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    forward.named_buffers = mod.named_buffers

    return forward

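The adapted aot_module hoists every parameter and buffer into the argument list once, so the compiled function never re-reads them from the module at call time. A minimal sketch of just that flattening step on a toy module (the module and the print are illustrative, not part of the commit):

import torch
import torch.utils._pytree as pytree
from torch import nn

mod = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# Collect parameters and buffers under their qualified names, as aot_module does.
params = {
    **dict(mod.named_parameters(remove_duplicate=False)),
    **dict(mod.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)

# The dispatcher later sees [*params_flat, *runtime_args] as the full argument list.
print(len(params_flat))  # 4 parameters + 3 batchnorm buffers = 7 tensors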
py/torch_tensorrt/dynamo/backend/backends.py (+16 -4)
@@ -3,6 +3,7 @@
 import torch
 from functools import partial
 import torch._dynamo as td
+from torch._guards import TracingContext

 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.lowering._decompositions import (
@@ -15,10 +16,12 @@
     partition,
     get_submod_inputs,
 )
+from torch_tensorrt.dynamo.lowering._freeze_aot_graph import freeze_autograd_gm
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 from torch_tensorrt.dynamo.conversion import convert_module

-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+from torch._functorch.aot_autograd import make_boxed_compiler
+from .aot_module import aot_module


 logger = logging.getLogger(__name__)
@@ -30,6 +33,8 @@ def torch_tensorrt_backend(
 ):
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

+    TracingContext.get().fake_mode.allow_non_fake_inputs = True
+
     return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


@@ -48,7 +53,7 @@ def aot_torch_tensorrt_aten_backend(
     gm = pre_aot_substitutions(gm)

     # Invoke AOTAutograd to translate operators to aten
-    return aot_module_simplified(
+    return aot_module(
         gm,
         sample_inputs,
         fw_compiler=make_boxed_compiler(custom_backend),
@@ -73,9 +78,16 @@ def _pretraced_backend(
     try:
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

+        frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
+        nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
+
+        frozen_gm.graph.eliminate_dead_code()
+        frozen_gm.graph.lint()
+        frozen_gm.recompile()
+
         trt_compiled = _compile_module(
-            gm,
-            sample_inputs,
+            frozen_gm,
+            nonfrozen_inputs,
             settings=settings,
         )
         return trt_compiled

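Taken together, the backend now freezes the AOT Autograd graph before conversion and hands only the surviving non-parameter inputs to _compile_module. A hedged sketch of how the backend is typically driven from user code; the backend name "torch_tensorrt" and CUDA availability are assumptions, not something this diff shows:

import torch
import torch_tensorrt  # noqa: F401  (importing registers the Dynamo backend; assumed)
from torch import nn

# Toy model; after freezing, its weights become graph constants rather than TRT inputs.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval().cuda()

# Backend name is illustrative only.
compiled = torch.compile(model, backend="torch_tensorrt")

with torch.no_grad():
    out = compiled(torch.randn(1, 3, 32, 32, device="cuda"))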
py/torch_tensorrt/dynamo/conversion/trt_interpreter.py (+30)
@@ -22,6 +22,8 @@
     unified_dtype_converter,
     Frameworks,
 )
+from torch.utils._python_dispatch import _disable_current_modes
+

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -296,6 +298,21 @@ def call_function(self, target, args, kwargs):
         assert self._cur_node_name is not None
         return converter(self.network, target, args, kwargs, self._cur_node_name)

+    def get_attr(self, target, args, kwargs):
+        with _disable_current_modes():
+            from torch_tensorrt.fx.converters import to_numpy
+
+            frozen_attr = self.fetch_attr(target)
+
+            if isinstance(frozen_attr, torch.nn.Parameter):
+                constant_tensor = frozen_attr.data
+            else:
+                constant_tensor = frozen_attr
+
+            network_constant = to_numpy(constant_tensor)
+
+        return network_constant
+
     def call_method(self, target, args, kwargs):
         assert isinstance(target, str)
         converter = CONVERTERS.get(target)
@@ -317,6 +334,17 @@ def output(self, target, args, kwargs):
         else:
             outputs = (args[0],)

+        for output_idx in range(len(outputs)):
+            from torch_tensorrt.fx.converters import get_trt_tensor
+
+            output = outputs[output_idx]
+
+            if not isinstance(output, trt.tensorrt.ITensor):
+                new_output = get_trt_tensor(self.network, output, target)
+                outputs = (
+                    outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
+                )
+
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")

@@ -356,3 +384,5 @@ def output(self, target, args, kwargs):
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
+
+        return list(outputs)

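The new get_attr override lets the interpreter resolve frozen weights to plain numpy constants instead of treating them as network inputs. Below is a self-contained toy analogue of that pattern using a bare torch.fx.Interpreter; the class and module names are made up for illustration, and the TRT-specific to_numpy and network plumbing is intentionally left out:

import torch
import torch.fx as fx


class ConstantGetAttrInterpreter(fx.Interpreter):
    """Toy interpreter: get_attr nodes evaluate to numpy constants."""

    def get_attr(self, target, args, kwargs):
        frozen_attr = self.fetch_attr(target)
        # Parameters carry autograd state; use the underlying tensor data.
        tensor = frozen_attr.data if isinstance(frozen_attr, torch.nn.Parameter) else frozen_attr
        return tensor.detach().cpu().numpy()


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3))

    def forward(self, x):
        return x + self.weight


gm = fx.symbolic_trace(Toy())
interp = ConstantGetAttrInterpreter(gm)
print(interp.get_attr("weight", (), {}))  # numpy array [1. 1. 1.], no TRT input needed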
py/torch_tensorrt/dynamo/lowering/__init__.py (+1)
@@ -8,3 +8,4 @@
 from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
 from .substitutions import *
 from ._fusers import *
+from ._freeze_aot_graph import *
@@ -0,0 +1,73 @@ (new file)

import torch
from typing import List, Tuple
from torch._inductor.freezing import replace_params_with_constants, constant_fold
from torch._inductor.compile_fx import fake_tensor_prop
from torch._functorch.compile_utils import fx_graph_cse
import torch.fx.traceback as fx_traceback
from torch._dynamo.utils import detect_fake_mode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.tools_common import legalize_graph
import unittest


def freeze_autograd_gm(
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: List[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Adapted from:
    https://github.com/pytorch/pytorch/blob/750b9b359f06cb8b8c2d5b6118bba636e2112cbb/torch/_inductor/freezing.py#L186-L243

    Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
    and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.

    Assumes that this function is run in dynamo tracing post aot_autograd.

    Args:
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen.
        example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
        of the inputs that were preserved (not turned into constants).
    """
    # Extract necessary metadata and parameters
    fw_metadata = torch._guards.TracingContext.get().fw_metadata
    params_flat = torch._guards.TracingContext.get().params_flat
    assert fw_metadata is not None and params_flat is not None

    # Replace placeholders with get_attr nodes
    preserved_arg_indices = replace_params_with_constants(
        aot_autograd_gm, params_flat, fw_metadata
    )

    constant_fold(aot_autograd_gm)

    fake_mode = detect_fake_mode(example_inputs)

    # constant params will be real tensors, not fake
    # TODO: fake_mode should enable py dispatcher if it's symbolic?
    with unittest.mock.patch.object(
        fake_mode, "allow_non_fake_inputs", True
    ), fake_mode:
        args = [e for i, e in enumerate(example_inputs) if i in preserved_arg_indices]
        with fx_traceback.preserve_node_meta():
            aot_autograd_gm = make_fx(aot_autograd_gm, _allow_non_fake_inputs=True)(
                *args
            )

    # TODO - further restrict cse? right now needed to dedup aliasing ops
    cse_graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.graph = cse_graph
    aot_autograd_gm.recompile()

    # Make sure meta['val'] is properly set up (weight conversion
    # or decompose_unfused_batchnorms lost meta['val']).
    aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
    fake_tensor_prop(aot_autograd_gm, aot_example_inputs, True)

    # TODO - apply legalization in pattern matcher
    legalize_graph(aot_autograd_gm)
    constant_fold(aot_autograd_gm)

    return aot_autograd_gm, preserved_arg_indices

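freeze_autograd_gm delegates the heavy lifting to Inductor's replace_params_with_constants and constant_fold. The sketch below is a from-scratch toy of the core idea, rewriting a parameter placeholder into a get_attr on a registered constant; it does not use the upstream passes, and names such as freeze_first_input and _frozen_param0 are invented for the example:

import torch
import torch.fx as fx


def freeze_first_input(gm: fx.GraphModule, value: torch.Tensor) -> fx.GraphModule:
    """Replace the first placeholder with a get_attr on a registered constant."""
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    target = placeholders[0]
    # Store the constant on the module so a get_attr node can reference it.
    gm.register_buffer("_frozen_param0", value)
    with gm.graph.inserting_before(target):
        const = gm.graph.get_attr("_frozen_param0")
    target.replace_all_uses_with(const)
    gm.graph.erase_node(target)
    gm.graph.lint()
    gm.recompile()
    return gm


def f(w, x):
    return w * x + 1


gm = fx.symbolic_trace(f)
gm = freeze_first_input(gm, torch.full((3,), 2.0))
print(gm(torch.ones(3)))  # tensor([3., 3., 3.]); w is now baked into the graph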
py/torch_tensorrt/dynamo/lowering/_partition.py (+2 -2)
@@ -125,8 +125,8 @@ def is_node_supported(

         if (
             node.target in CONVERTERS.keys()
-            and node_name not in self.torch_executed_ops
-        ):
+            or (node.op == "get_attr" and "frozen" in node_name)
+        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
                 self.supported_operators.add(node_name)

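The partitioner tweak adjusts operator precedence: a node now qualifies if it either has a registered converter or is a frozen-constant get_attr, and in both cases the torch_executed_ops exclusion still applies. Restated as a standalone predicate with simplified arguments (illustrative only):

def node_is_supported(
    has_converter: bool, node_op: str, node_name: str, torch_executed_ops: set
) -> bool:
    # Frozen weights surface as get_attr nodes whose names contain "frozen".
    return (
        has_converter or (node_op == "get_attr" and "frozen" in node_name)
    ) and node_name not in torch_executed_ops


print(node_is_supported(False, "get_attr", "_frozen_param0", set()))  # True
print(node_is_supported(True, "call_function", "aten.add", {"aten.add"}))  # False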