
Commit ad22101

Update the fusion section of tuning_guide.py (#2889)
1 parent 2dd1997 commit ad22101

1 file changed: +21 −20 lines


recipes_source/recipes/tuning_guide.py

+21 −20
@@ -94,35 +94,36 @@
 # ``optimizer.zero_grad(set_to_none=True)``.
 
 ###############################################################################
-# Fuse pointwise operations
+# Fuse operations
 # ~~~~~~~~~~~~~~~~~~~~~~~~~
-# Pointwise operations (elementwise addition, multiplication, math functions -
-# ``sin()``, ``cos()``, ``sigmoid()`` etc.) can be fused into a single kernel
-# to amortize memory access time and kernel launch time.
-#
-# `PyTorch JIT <https://pytorch.org/docs/stable/jit.html>`_ can fuse kernels
-# automatically, although there could be additional fusion opportunities not yet
-# implemented in the compiler, and not all device types are supported equally.
-#
-# Pointwise operations are memory-bound, for each operation PyTorch launches a
-# separate kernel. Each kernel loads data from the memory, performs computation
-# (this step is usually inexpensive) and stores results back into the memory.
-#
-# Fused operator launches only one kernel for multiple fused pointwise ops and
-# loads/stores data only once to the memory. This makes JIT very useful for
-# activation functions, optimizers, custom RNN cells etc.
+# Pointwise operations such as elementwise addition, multiplication, and math
+# functions like `sin()`, `cos()`, `sigmoid()`, etc., can be combined into a
+# single kernel. This fusion helps reduce memory access and kernel launch times.
+# Typically, pointwise operations are memory-bound; PyTorch eager-mode initiates
+# a separate kernel for each operation, which involves loading data from memory,
+# executing the operation (often not the most time-consuming step), and writing
+# the results back to memory.
+#
+# By using a fused operator, only one kernel is launched for multiple pointwise
+# operations, and data is loaded and stored just once. This efficiency is
+# particularly beneficial for activation functions, optimizers, and custom RNN cells etc.
+#
+# PyTorch 2 introduces a compile-mode facilitated by TorchInductor, an underlying compiler
+# that automatically fuses kernels. TorchInductor extends its capabilities beyond simple
+# element-wise operations, enabling advanced fusion of eligible pointwise and reduction
+# operations for optimized performance.
 #
 # In the simplest case fusion can be enabled by applying
-# `torch.jit.script <https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_
+# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
 # decorator to the function definition, for example:
 
-@torch.jit.script
-def fused_gelu(x):
+@torch.compile
+def gelu(x):
     return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
 
 ###############################################################################
 # Refer to
-# `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
+# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
 # for more advanced use cases.
 
 ###############################################################################
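A minimal usage sketch of the compiled ``gelu`` shown in the diff above, assuming PyTorch 2.x with the default TorchInductor backend (the input shape below is illustrative, not part of the tutorial file):

import torch

@torch.compile
def gelu(x):
    # The pointwise ops below (mul, div, erf, add) are eligible for fusion
    # into a single kernel by TorchInductor.
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

x = torch.randn(1024, 1024)  # illustrative input tensor
y = gelu(x)  # first call triggers compilation; later calls with the same shape reuse the compiled kernel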
