Skip to content

Commit 1858ca6

Browse files
committed
Reviewed torch.compile tutorial
1 parent d181199 commit 1858ca6

File tree

1 file changed

+42
-29
lines changed

1 file changed

+42
-29
lines changed

Diff for: intermediate_source/torch_compile_tutorial.py

+42-29
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,6 @@ def opt_foo2(x, y):
8383
return a + b
8484
print(opt_foo2(t1, t2))
8585

86-
# When using the decorator approach, nested function calls within the decorated
87-
# function will also be compiled.
88-
89-
def nested_function(x):
90-
return torch.sin(x)
91-
92-
@torch.compile
93-
def outer_function(x, y):
94-
a = nested_function(x)
95-
b = torch.cos(y)
96-
return a + b
97-
98-
print(outer_function(t1, t2))
99-
10086
######################################################################
10187
# We can also optimize ``torch.nn.Module`` instances.
10288

@@ -114,8 +100,25 @@ def forward(self, x):
114100
opt_mod = torch.compile(mod)
115101
print(opt_mod(t))
116102

103+
######################################################################
104+
# torch.compile and Nested Calls
105+
# ------------------------------
106+
# Nested function calls within the decorated function will also be compiled.
107+
108+
def nested_function(x):
109+
return torch.sin(x)
110+
111+
@torch.compile
112+
def outer_function(x, y):
113+
a = nested_function(x)
114+
b = torch.cos(y)
115+
return a + b
116+
117+
print(outer_function(t1, t2))
118+
119+
######################################################################
117120
# In the same fashion, when compiling a module all sub-modules and methods
118-
# within it are also compiled.
121+
# within it, that are not in a skiplist, are also compiled.
119122

120123
class OuterModule(torch.nn.Module):
121124
def __init__(self):
@@ -133,12 +136,20 @@ def forward(self, x):
133136

134137
######################################################################
135138
# We can also disable some functions from being compiled by using
136-
# `torch.compiler.disable`
139+
# `torch.compiler.disable`. Suppose you want to disable the tracing on just
140+
# the `complex_function` function, but want to continue the tracing back in
141+
# `complex_conjugate`. In this case, you can use
142+
# `torch.compiler.disable(recursive=False)` option. Otherwise, the default is
143+
# `recursive=True`.
137144

138-
@torch.compiler.disable
145+
def complex_conjugate(z):
146+
return torch.conj(z)
147+
148+
@torch.compiler.disable(recursive=False)
139149
def complex_function(real, imag):
140150
# Assuming this function cause problems in the compilation
141-
return torch.complex(real, imag)
151+
z = torch.complex(real, imag)
152+
return complex_conjugate(z)
142153

143154
def outer_function():
144155
real = torch.tensor([2, 3], dtype=torch.float32)
@@ -159,25 +170,27 @@ def outer_function():
159170
#
160171
# Behavior of ``torch.compile`` with Nested Modules and Function Calls
161172
#
162-
# When you use ``torch.compile``, the compiler will try to recursively inline
163-
# and compile every function call inside the target function or module.
173+
# When you use ``torch.compile``, the compiler will try to recursively compile
174+
# every function call inside the target function or module inside the target
175+
# function or module that is not in a skiplist (e.g. builtins, some functions in
176+
# the torch.* namespace).
164177
#
165-
# This includes:
166-
#
167-
# - **Nested function calls:** All functions called within the decorated or compiled function will also be compiled.
168-
#
169-
# - **Nested modules:** If a ``torch.nn.Module`` is compiled, all sub-modules and functions within the module are also compiled.
170-
#
171178
# **Best Practices:**
172179
#
173-
# 1. **Modular Testing:** Test individual functions and modules with ``torch.compile``
180+
# 1. **Top-Level Compilation:** One approach is to compile at the highest level
181+
# possible (i.e., when the top-level module is initialized/called) and
182+
# selectively disable compilation when encountering excessive graph breaks or
183+
# errors. If there are still many compile issues, compile individual
184+
# subcomponents instead.
185+
#
186+
# 2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
174187
# before integrating them into larger models to isolate potential issues.
175188
#
176-
# 2. **Disable Compilation Selectively:** If certain functions or sub-modules
189+
# 3. **Disable Compilation Selectively:** If certain functions or sub-modules
177190
# cannot be handled by `torch.compile`, use the `torch.compiler.disable` context
178191
# managers to recursively exclude them from compilation.
179192
#
180-
# 3. **Compile Leaf Functions First:** In complex models with multiple nested
193+
# 4. **Compile Leaf Functions First:** In complex models with multiple nested
181194
# functions and modules, start by compiling the leaf functions or modules first.
182195
# For more information see `TorchDynamo APIs for fine-grained tracing <https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html>`__.
183196

0 commit comments

Comments
 (0)