 from __future__ import annotations
 
-import copy
 import logging
-import sys
-from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+import unittest.mock
+from typing import Any, Tuple
 
 import torch
-import torch._dynamo as torchdynamo
-from torch.fx.passes.infra.pass_base import PassResult
-from torch_tensorrt.dynamo.utils import req_torch_version
-from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
-    compose_bmm,
-    compose_chunk,
-    compose_getitem_slice,
-    remove_ops,
-    replace_aten_op_with_indices,
-    replace_aten_reshape_alias_with_replace,
-    replace_builtin_ops,
-    replace_inplace_ops,
-    replace_native_layernorm_with_layernorm,
-    replace_transpose_mm_op_with_linear,
-    run_const_fold,
-)
-from typing_extensions import TypeAlias
-
-Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]
+from torch._export import export
+from torch_tensorrt.dynamo.backend.backends import constant_fold
+from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.utils import set_log_level
 
 logger = logging.getLogger(__name__)
 
 
-class DynamoConfig:
-    """
-    Manage Exir-specific configurations of Dynamo.
-    """
-
-    def __init__(
-        self,
-        capture_scalar_outputs: bool = True,
-        guard_nn_modules: bool = True,
-        dynamic_shapes: bool = True,
-        specialize_int: bool = True,
-        verbose: bool = True,
-    ) -> None:
-        self.capture_scalar_outputs = capture_scalar_outputs
-        self.guard_nn_modules = guard_nn_modules
-        self.dynamic_shapes = dynamic_shapes
-        self.specialize_int = specialize_int
-        self.verbose = verbose
-
-    def activate(self) -> None:
-        torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs
-        torchdynamo.config.guard_nn_modules = self.guard_nn_modules
-        torchdynamo.config.dynamic_shapes = self.dynamic_shapes
-        torchdynamo.config.specialize_int = self.specialize_int
-        torchdynamo.config.verbose = self.verbose
-
-    def deactivate(self) -> None:
-        torchdynamo.config.capture_scalar_outputs = True
-        torchdynamo.config.guard_nn_modules = True
-        torchdynamo.config.dynamic_shapes = True
-        torchdynamo.config.specialize_int = True
-        torchdynamo.config.verbose = True
-
-
-@contextmanager
-def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]:
-    config.activate()
-    try:
-        yield config
-    finally:
-        config.deactivate()
-
-
-@contextmanager
-def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
-    """
-    Temporarily increase the python interpreter stack recursion limit.
-    This is mostly used for pickling large scale modules.
-    """
-    default = sys.getrecursionlimit()
-    if limit > default:
-        sys.setrecursionlimit(limit)
-    try:
-        yield
-    finally:
-        sys.setrecursionlimit(default)
-
-
-@req_torch_version("2.dev")
-def dynamo_trace(
-    f: Callable[..., Value],
-    # pyre-ignore
-    args: Tuple[Any, ...],
-    aten_graph: bool,
-    tracing_mode: str = "real",
-    dynamo_config: Optional[DynamoConfig] = None,
-) -> Any:  # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
-    """
-    TODO: Once we fully migrate to torchdynamo frontend, we will remove
-    this config option alltogether. For now, it helps with quick
-    experiments with playing around with TorchDynamo
-    """
-    if dynamo_config is None:
-        dynamo_config = DynamoConfig()
-    with using_config(dynamo_config), setting_python_recursive_limit(2000):
-        torchdynamo.reset()
-        try:
-            return torchdynamo.export(
-                f,
-                *copy.deepcopy(args),
-                aten_graph=aten_graph,
-                tracing_mode=tracing_mode,
-            )
-        except torchdynamo.exc.Unsupported as exc:
-            raise RuntimeError(
-                "The user code is using a feature we don't support. "
-                "Please try torchdynamo.explain() to get possible the reasons",
-            ) from exc
-        except Exception as exc:
-            raise RuntimeError(
-                "torchdynamo internal error occured. Please see above stacktrace"
-            ) from exc
-
-
-@req_torch_version("2.dev")
 def trace(
     model: torch.nn.Module | torch.fx.GraphModule,
     inputs: Tuple[Any, ...],
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
-    """
-    Optimized trace with necessary passes which re-compose some ops or replace some ops
-    These passes should be general and functional purpose
-    """
-    passes_list = [
-        compose_bmm,
-        compose_chunk,
-        compose_getitem_slice,
-        replace_aten_reshape_alias_with_replace,
-        replace_aten_op_with_indices,
-        replace_transpose_mm_op_with_linear,  # after compose_bmm
-        replace_native_layernorm_with_layernorm,
-        remove_ops,
-        replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
-        replace_inplace_ops,  # remove it once functionalization is enabled
-    ]
-
-    fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
-
-    for passes in passes_list:
-        pr: PassResult = passes(fx_module)
-        fx_module = pr.graph_module
-
-    fx_module(*inputs)
-
-    fx_module = run_const_fold(fx_module)
-    logger.info("Post export graph : %s\n", fx_module.graph)
-    return fx_module
+    # Set log level at the top of compilation (torch_tensorrt.dynamo)
+    if "debug" in kwargs and kwargs["debug"]:
+        set_log_level(logger.parent, logging.DEBUG)
+
+    experimental_decompositions = kwargs.get(
+        "enable_experimental_decompositions", False
+    )
+    with unittest.mock.patch(
+        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
+    ):
+        graph_module = export(model, tuple(inputs)).module()
+        constant_fold(graph_module)
+        logger.debug("Post export graph: " + str(graph_module.graph))
+        return graph_module
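
A minimal usage sketch of the reworked trace() entry point, assuming this module is importable (its file path is not shown in this hunk), so the actual call is left commented out; the toy model, input shapes, and keyword arguments are illustrative assumptions, not part of the change.

# Hypothetical caller of the new trace() API above; ToyModel and the shapes are assumptions.
import torch

class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0

example_inputs = (torch.randn(2, 3),)  # example inputs, passed as a tuple

# trace() exports the module to an aten-level graph under the patched decomposition
# table, constant-folds it, and returns the resulting torch.fx.GraphModule.
# graph_module = trace(ToyModel().eval(), example_inputs, debug=True)
# print(graph_module.graph)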