Skip to content

Commit 88ba994

Browse files
committed
fix: Refactor tensor freezing in Dynamo
1 parent beecf35 commit 88ba994

File tree

4 files changed

+21
-224
lines changed

4 files changed

+21
-224
lines changed

py/torch_tensorrt/dynamo/backend/aot_module.py

-127
This file was deleted.

py/torch_tensorrt/dynamo/backend/backends.py

+21-23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
22
from typing import Sequence
33
import torch
4+
from torch._dynamo.utils import detect_fake_mode
5+
import unittest
46
from functools import partial
57
import torch._dynamo as td
68
from torch._guards import TracingContext
@@ -23,8 +25,7 @@
2325
repair_long_or_double_inputs,
2426
)
2527

26-
from torch._functorch.aot_autograd import make_boxed_compiler
27-
from .aot_module import aot_module
28+
from torch._functorch.aot_autograd import aot_export_joint_simple
2829

2930

3031
logger = logging.getLogger(__name__)
@@ -47,21 +48,25 @@ def aot_torch_tensorrt_aten_backend(
4748
):
4849
settings = parse_dynamo_kwargs(kwargs)
4950

50-
custom_backend = partial(
51-
_pretraced_backend,
52-
settings=settings,
53-
)
54-
5551
# Perform Pre-AOT Lowering for Module-Level Replacement
5652
gm = pre_aot_substitutions(gm)
5753

58-
# Invoke AOTAutograd to translate operators to aten
59-
return aot_module(
60-
gm,
61-
sample_inputs,
62-
fw_compiler=make_boxed_compiler(custom_backend),
63-
decompositions=get_decompositions(),
64-
)
54+
fake_mode = detect_fake_mode(sample_inputs)
55+
56+
# Place backend tracing within the FakeTensor context, allowing non-fake Tensors as inputs
57+
with unittest.mock.patch.object(
58+
fake_mode, "allow_non_fake_inputs", True
59+
), fake_mode:
60+
61+
# Invoke AOTAutograd to translate operators to aten
62+
graph_module = aot_export_joint_simple(
63+
gm,
64+
sample_inputs,
65+
trace_joint=False,
66+
decompositions=get_decompositions(),
67+
)
68+
69+
return _pretraced_backend(graph_module, sample_inputs, settings)
6570

6671

6772
def _pretraced_backend(
@@ -81,16 +86,9 @@ def _pretraced_backend(
8186
try:
8287
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
8388

84-
frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
85-
nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
86-
87-
frozen_gm.graph.eliminate_dead_code()
88-
frozen_gm.graph.lint()
89-
frozen_gm.recompile()
90-
9189
trt_compiled = _compile_module(
92-
frozen_gm,
93-
nonfrozen_inputs,
90+
gm,
91+
sample_inputs,
9492
settings=settings,
9593
)
9694
return trt_compiled

py/torch_tensorrt/dynamo/lowering/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@
88
from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
99
from .substitutions import *
1010
from ._fusers import *
11-
from ._freeze_aot_graph import *

py/torch_tensorrt/dynamo/lowering/_freeze_aot_graph.py

-73
This file was deleted.

0 commit comments

Comments
 (0)