Commit c1f130a

Authored by peri044 and gs-olive
feat: Transition export workflows to use torch._export APIs (#2195)
Signed-off-by: Dheeraj Peri <[email protected]>
Co-authored-by: gs-olive <[email protected]>
1 parent: 43eb4bb

6 files changed: +93 additions, −194 deletions
+20 −146

```diff
@@ -1,160 +1,34 @@
 from __future__ import annotations
 
-import copy
 import logging
-import sys
-from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+import unittest.mock
+from typing import Any, Tuple
 
 import torch
-import torch._dynamo as torchdynamo
-from torch.fx.passes.infra.pass_base import PassResult
-from torch_tensorrt.dynamo.utils import req_torch_version
-from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
-    compose_bmm,
-    compose_chunk,
-    compose_getitem_slice,
-    remove_ops,
-    replace_aten_op_with_indices,
-    replace_aten_reshape_alias_with_replace,
-    replace_builtin_ops,
-    replace_inplace_ops,
-    replace_native_layernorm_with_layernorm,
-    replace_transpose_mm_op_with_linear,
-    run_const_fold,
-)
-from typing_extensions import TypeAlias
-
-Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]
+from torch._export import export
+from torch_tensorrt.dynamo.backend.backends import constant_fold
+from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.utils import set_log_level
 
 logger = logging.getLogger(__name__)
 
 
-class DynamoConfig:
-    """
-    Manage Exir-specific configurations of Dynamo.
-    """
-
-    def __init__(
-        self,
-        capture_scalar_outputs: bool = True,
-        guard_nn_modules: bool = True,
-        dynamic_shapes: bool = True,
-        specialize_int: bool = True,
-        verbose: bool = True,
-    ) -> None:
-        self.capture_scalar_outputs = capture_scalar_outputs
-        self.guard_nn_modules = guard_nn_modules
-        self.dynamic_shapes = dynamic_shapes
-        self.specialize_int = specialize_int
-        self.verbose = verbose
-
-    def activate(self) -> None:
-        torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs
-        torchdynamo.config.guard_nn_modules = self.guard_nn_modules
-        torchdynamo.config.dynamic_shapes = self.dynamic_shapes
-        torchdynamo.config.specialize_int = self.specialize_int
-        torchdynamo.config.verbose = self.verbose
-
-    def deactivate(self) -> None:
-        torchdynamo.config.capture_scalar_outputs = True
-        torchdynamo.config.guard_nn_modules = True
-        torchdynamo.config.dynamic_shapes = True
-        torchdynamo.config.specialize_int = True
-        torchdynamo.config.verbose = True
-
-
-@contextmanager
-def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]:
-    config.activate()
-    try:
-        yield config
-    finally:
-        config.deactivate()
-
-
-@contextmanager
-def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
-    """
-    Temporarily increase the python interpreter stack recursion limit.
-    This is mostly used for pickling large scale modules.
-    """
-    default = sys.getrecursionlimit()
-    if limit > default:
-        sys.setrecursionlimit(limit)
-    try:
-        yield
-    finally:
-        sys.setrecursionlimit(default)
-
-
-@req_torch_version("2.dev")
-def dynamo_trace(
-    f: Callable[..., Value],
-    # pyre-ignore
-    args: Tuple[Any, ...],
-    aten_graph: bool,
-    tracing_mode: str = "real",
-    dynamo_config: Optional[DynamoConfig] = None,
-) -> Any:  # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
-    """
-    TODO: Once we fully migrate to torchdynamo frontend, we will remove
-    this config option alltogether. For now, it helps with quick
-    experiments with playing around with TorchDynamo
-    """
-    if dynamo_config is None:
-        dynamo_config = DynamoConfig()
-    with using_config(dynamo_config), setting_python_recursive_limit(2000):
-        torchdynamo.reset()
-        try:
-            return torchdynamo.export(
-                f,
-                *copy.deepcopy(args),
-                aten_graph=aten_graph,
-                tracing_mode=tracing_mode,
-            )
-        except torchdynamo.exc.Unsupported as exc:
-            raise RuntimeError(
-                "The user code is using a feature we don't support. "
-                "Please try torchdynamo.explain() to get possible the reasons",
-            ) from exc
-        except Exception as exc:
-            raise RuntimeError(
-                "torchdynamo internal error occured. Please see above stacktrace"
-            ) from exc
-
-
-@req_torch_version("2.dev")
 def trace(
     model: torch.nn.Module | torch.fx.GraphModule,
     inputs: Tuple[Any, ...],
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
-    """
-    Optimized trace with necessary passes which re-compose some ops or replace some ops
-    These passes should be general and functional purpose
-    """
-    passes_list = [
-        compose_bmm,
-        compose_chunk,
-        compose_getitem_slice,
-        replace_aten_reshape_alias_with_replace,
-        replace_aten_op_with_indices,
-        replace_transpose_mm_op_with_linear,  # after compose_bmm
-        replace_native_layernorm_with_layernorm,
-        remove_ops,
-        replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
-        replace_inplace_ops,  # remove it once functionalization is enabled
-    ]
-
-    fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
-
-    for passes in passes_list:
-        pr: PassResult = passes(fx_module)
-        fx_module = pr.graph_module
-
-    fx_module(*inputs)
-
-    fx_module = run_const_fold(fx_module)
-    logger.info("Post export graph : %s\n", fx_module.graph)
-    return fx_module
+    # Set log level at the top of compilation (torch_tensorrt.dynamo)
+    if "debug" in kwargs and kwargs["debug"]:
+        set_log_level(logger.parent, logging.DEBUG)
+
+    experimental_decompositions = kwargs.get(
+        "enable_experimental_decompositions", False
+    )
+    with unittest.mock.patch(
+        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
+    ):
+        graph_module = export(model, tuple(inputs)).module()
+        constant_fold(graph_module)
+        logger.debug("Post export graph: " + str(graph_module.graph))
+        return graph_module
```
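The rewritten `trace` drops the Dynamo-specific config plumbing and the FX pass list in favor of a single `torch._export.export` call, with `torch._export.DECOMP_TABLE` patched so Torch-TensorRT's decompositions run during export and the result constant-folded afterwards. As a rough usage sketch of the new entry point (the module path `torch_tensorrt.dynamo.aten_tracer` and the ResNet model are assumptions; the changed file's name is not captured above):

```python
import torch
import torchvision.models as models

# Module path assumed; the changed file's name is not shown in this page.
from torch_tensorrt.dynamo.aten_tracer import trace

# Placeholder model/inputs for illustration only
model = models.resnet18(pretrained=True).eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

# Exports via torch._export.export() under the patched decomposition
# table, then constant-folds the resulting GraphModule.
graph_module = trace(
    model,
    inputs,
    debug=True,  # raises the torch_tensorrt.dynamo log level to DEBUG
    enable_experimental_decompositions=False,
)
print(graph_module.graph)
```

Patching `DECOMP_TABLE` through `unittest.mock.patch` keeps the override scoped to the `with` block, so the global table is restored once export returns.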

py/torch_tensorrt/dynamo/backend/backends.py

+8 −11

```diff
@@ -7,20 +7,20 @@
 import torch
 import torch._dynamo as td
 import torch.utils._pytree as pytree
-import torch_tensorrt
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
+from torch_tensorrt._utils import sanitized_torch_version
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
-from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
+from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level
 
 from packaging import version
 
 # Modify import location of utilities based on Torch version
-if version.parse(torch_tensorrt.sanitized_torch_version()) < version.parse("2.1.1"):
+if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
     from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
 else:
     from torch._inductor.constant_folding import (
@@ -38,14 +38,11 @@ def torch_tensorrt_backend(
 ) -> torch.nn.Module:
     # Set log level at the top of compilation (torch_tensorrt.dynamo)
     if (
-        (
-            "options" in kwargs
-            and "debug" in kwargs["options"]
-            and kwargs["options"]["debug"]
-        )
-        or ("debug" in kwargs and kwargs["debug"])
-    ) and logger.parent:
-        logger.parent.setLevel(logging.DEBUG)
+        "options" in kwargs
+        and "debug" in kwargs["options"]
+        and kwargs["options"]["debug"]
+    ) or ("debug" in kwargs and kwargs["debug"]):
+        set_log_level(logger.parent, logging.DEBUG)
 
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
```
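The simplified condition accepts the debug flag either nested in the torch.compile `options` dict or as a top-level kwarg, and delegates the `logger.parent` None-check to the new `set_log_level` helper. A minimal sketch of how the flag reaches `torch_tensorrt_backend` (the Linear model is a placeholder, not from this commit):

```python
import torch
import torch_tensorrt  # noqa: F401  # importing registers the "torch_tensorrt" backend

model = torch.nn.Linear(8, 8).eval().cuda()  # placeholder model
x = torch.randn((4, 8)).cuda()

# "debug" arrives in the backend's kwargs under "options", matching the
# first branch of the rewritten condition above.
compiled = torch.compile(model, backend="torch_tensorrt", options={"debug": True})
compiled(x)
```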

py/torch_tensorrt/dynamo/compile.py

+2 −2

```diff
@@ -33,6 +33,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     prepare_inputs,
+    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -72,8 +73,7 @@ def compile(
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     if debug:
-        if logger.parent:
-            logger.parent.setLevel(logging.DEBUG)
+        set_log_level(logger.parent, logging.DEBUG)
 
     enabled_precisions = set(enabled_precisions)
 
```
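The same helper replaces the inline `logger.parent` guard at the dynamo compile entry point. A hedged usage sketch with a placeholder module (the model and input shapes are assumptions):

```python
import torch
import torch_tensorrt

model = torch.nn.Linear(8, 8).eval().cuda()  # placeholder model

# debug=True now routes through set_log_level() rather than touching
# logger.parent directly at the call site.
trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[torch_tensorrt.Input((4, 8))],
    debug=True,
)
```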

py/torch_tensorrt/dynamo/utils.py

+10 −0

```diff
@@ -63,6 +63,16 @@ def cosine_similarity(gt_tensor: torch.Tensor, pred_tensor: torch.Tensor) -> flo
     return res
 
 
+def set_log_level(parent_logger: Any, level: Any) -> None:
+    """
+    Sets the log level to the user provided level.
+    This is used to set debug logging at a global level
+    at entry points of tracing, dynamo and torch_compile compilation.
+    """
+    if parent_logger:
+        parent_logger.setLevel(level)
+
+
 def prepare_inputs(
     inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
     device: torch.device = torch.device("cuda"),
```
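`set_log_level` centralizes the `if logger.parent:` guards that this commit removes from the individual entry points. A minimal sketch of the helper in isolation:

```python
import logging

from torch_tensorrt.dynamo.utils import set_log_level

# Typical call: raise the parent (torch_tensorrt) logger to DEBUG so every
# child logger in the package inherits the level.
pkg_logger = logging.getLogger("torch_tensorrt.dynamo")
set_log_level(pkg_logger.parent, logging.DEBUG)

# The truthiness check inside the helper makes a missing parent a no-op,
# replacing the repeated `if logger.parent:` guards at call sites.
set_log_level(None, logging.DEBUG)
```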

tests/py/dynamo/models/test_models.py

+0 −15

```diff
@@ -40,9 +40,6 @@ def test_resnet18(ir):
     # Clean up model env
     torch._dynamo.reset()
 
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_mobilenet_v2(ir):
@@ -74,9 +71,6 @@ def test_mobilenet_v2(ir):
     # Clean up model env
     torch._dynamo.reset()
 
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_efficientnet_b0(ir):
@@ -108,9 +102,6 @@ def test_efficientnet_b0(ir):
     # Clean up model env
     torch._dynamo.reset()
 
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_bert_base_uncased(ir):
@@ -155,9 +146,6 @@ def test_bert_base_uncased(ir):
     # Clean up model env
     torch._dynamo.reset()
 
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_resnet18_half(ir):
@@ -187,6 +175,3 @@ def test_resnet18_half(ir):
 
     # Clean up model env
     torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
```
