@@ -1,6 +1,8 @@
 import logging
 from typing import Sequence
 import torch
+from torch._dynamo.utils import detect_fake_mode
+import unittest
 from functools import partial
 import torch._dynamo as td
 from torch._guards import TracingContext
@@ -23,8 +25,7 @@
     repair_long_or_double_inputs,
 )
 
-from torch._functorch.aot_autograd import make_boxed_compiler
-from .aot_module import aot_module
+from torch._functorch.aot_autograd import aot_export_joint_simple
 
 
 logger = logging.getLogger(__name__)
@@ -47,21 +48,25 @@ def aot_torch_tensorrt_aten_backend(
 ):
     settings = parse_dynamo_kwargs(kwargs)
 
-    custom_backend = partial(
-        _pretraced_backend,
-        settings=settings,
-    )
-
     # Perform Pre-AOT Lowering for Module-Level Replacement
     gm = pre_aot_substitutions(gm)
 
-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(),
-    )
+    fake_mode = detect_fake_mode(sample_inputs)
+
+    # Place backend tracing within FakeTensor context allowing nonfake Tensors
+    with unittest.mock.patch.object(
+        fake_mode, "allow_non_fake_inputs", True
+    ), fake_mode:
+
+        # Invoke AOTAutograd to translate operators to aten
+        graph_module = aot_export_joint_simple(
+            gm,
+            sample_inputs,
+            trace_joint=False,
+            decompositions=get_decompositions(),
+        )
+
+        return _pretraced_backend(graph_module, sample_inputs, settings)
 
 
 def _pretraced_backend(
@@ -81,16 +86,9 @@ def _pretraced_backend(
     try:
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-        frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
-        nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
-
-        frozen_gm.graph.eliminate_dead_code()
-        frozen_gm.graph.lint()
-        frozen_gm.recompile()
-
         trt_compiled = _compile_module(
-            frozen_gm,
-            nonfrozen_inputs,
+            gm,
+            sample_inputs,
             settings=settings,
         )
         return trt_compiled
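
For reference, a minimal usage sketch of the changed backend path (assumptions: torch_tensorrt still registers its dynamo backend under the name "torch_tensorrt", a CUDA device is available, and the toy model/shapes are illustrative only, not part of this change):

import torch
import torch_tensorrt  # noqa: F401  -- importing registers the dynamo backend

# Hypothetical toy model and sample inputs for illustration
model = torch.nn.Linear(8, 4).eval().cuda()
inputs = [torch.randn(2, 8, device="cuda")]

# torch.compile dispatches to aot_torch_tensorrt_aten_backend, which after this
# change exports an aten-level graph via aot_export_joint_simple and then hands
# the pretraced graph to _pretraced_backend / _compile_module
compiled = torch.compile(model, backend="torch_tensorrt")
out = compiled(*inputs)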