@@ -83,20 +83,6 @@ def opt_foo2(x, y):
83
83
return a + b
84
84
print (opt_foo2 (t1 , t2 ))
85
85
86
- # When using the decorator approach, nested function calls within the decorated
87
- # function will also be compiled.
88
-
89
- def nested_function (x ):
90
- return torch .sin (x )
91
-
92
- @torch .compile
93
- def outer_function (x , y ):
94
- a = nested_function (x )
95
- b = torch .cos (y )
96
- return a + b
97
-
98
- print (outer_function (t1 , t2 ))
99
-
100
86
######################################################################
101
87
# We can also optimize ``torch.nn.Module`` instances.
102
88
@@ -114,8 +100,25 @@ def forward(self, x):
114
100
opt_mod = torch .compile (mod )
115
101
print (opt_mod (t ))
116
102
103
+ ######################################################################
104
+ # torch.compile and Nested Calls
105
+ # ------------------------------
106
+ # Nested function calls within the decorated function will also be compiled.
107
+
108
+ def nested_function (x ):
109
+ return torch .sin (x )
110
+
111
+ @torch .compile
112
+ def outer_function (x , y ):
113
+ a = nested_function (x )
114
+ b = torch .cos (y )
115
+ return a + b
116
+
117
+ print (outer_function (t1 , t2 ))
118
+
119
+ ######################################################################
117
120
# In the same fashion, when compiling a module all sub-modules and methods
118
- # within it are also compiled.
121
+ # within it, that are not in a skiplist, are also compiled.
119
122
120
123
class OuterModule (torch .nn .Module ):
121
124
def __init__ (self ):
@@ -133,12 +136,20 @@ def forward(self, x):
133
136
134
137
######################################################################
135
138
# We can also disable some functions from being compiled by using
136
- # `torch.compiler.disable`
139
+ # `torch.compiler.disable`. Suppose you want to disable the tracing on just
140
+ # the `complex_function` function, but want to continue the tracing back in
141
+ # `complex_conjugate`. In this case, you can use
142
+ # `torch.compiler.disable(recursive=False)` option. Otherwise, the default is
143
+ # `recursive=True`.
137
144
138
- @torch .compiler .disable
145
+ def complex_conjugate (z ):
146
+ return torch .conj (z )
147
+
148
+ @torch .compiler .disable (recursive = False )
139
149
def complex_function (real , imag ):
140
150
# Assuming this function cause problems in the compilation
141
- return torch .complex (real , imag )
151
+ z = torch .complex (real , imag )
152
+ return complex_conjugate (z )
142
153
143
154
def outer_function ():
144
155
real = torch .tensor ([2 , 3 ], dtype = torch .float32 )
@@ -159,25 +170,27 @@ def outer_function():
159
170
#
160
171
# Behavior of ``torch.compile`` with Nested Modules and Function Calls
161
172
#
162
- # When you use ``torch.compile``, the compiler will try to recursively inline
163
- # and compile every function call inside the target function or module.
173
+ # When you use ``torch.compile``, the compiler will try to recursively compile
174
+ # every function call inside the target function or module inside the target
175
+ # function or module that is not in a skiplist (e.g. builtins, some functions in
176
+ # the torch.* namespace).
164
177
#
165
- # This includes:
166
- #
167
- # - **Nested function calls:** All functions called within the decorated or compiled function will also be compiled.
168
- #
169
- # - **Nested modules:** If a ``torch.nn.Module`` is compiled, all sub-modules and functions within the module are also compiled.
170
- #
171
178
# **Best Practices:**
172
179
#
173
- # 1. **Modular Testing:** Test individual functions and modules with ``torch.compile``
180
+ # 1. **Top-Level Compilation:** One approach is to compile at the highest level
181
+ # possible (i.e., when the top-level module is initialized/called) and
182
+ # selectively disable compilation when encountering excessive graph breaks or
183
+ # errors. If there are still many compile issues, compile individual
184
+ # subcomponents instead.
185
+ #
186
+ # 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
174
187
# before integrating them into larger models to isolate potential issues.
175
188
#
176
- # 2 . **Disable Compilation Selectively:** If certain functions or sub-modules
189
+ # 3 . **Disable Compilation Selectively:** If certain functions or sub-modules
177
190
# cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
178
191
# managers to recursively exclude them from compilation.
179
192
#
180
- # 3 . **Compile Leaf Functions First:** In complex models with multiple nested
193
+ # 4 . **Compile Leaf Functions First:** In complex models with multiple nested
181
194
# functions and modules, start by compiling the leaf functions or modules first.
182
195
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
183
196
0 commit comments