|
 94 |  94 | # ``optimizer.zero_grad(set_to_none=True)``.
 95 |  95 |
 96 |  96 | ###############################################################################
 97 |     | -# Fuse pointwise operations
    |  97 | +# Fuse operations
 98 |  98 | # ~~~~~~~~~~~~~~~~~~~~~~~~~
 99 |     | -# Pointwise operations (elementwise addition, multiplication, math functions -
100 |     | -# ``sin()``, ``cos()``, ``sigmoid()`` etc.) can be fused into a single kernel
101 |     | -# to amortize memory access time and kernel launch time.
102 |     | -#
103 |     | -# `PyTorch JIT <https://pytorch.org/docs/stable/jit.html>`_ can fuse kernels
104 |     | -# automatically, although there could be additional fusion opportunities not yet
105 |     | -# implemented in the compiler, and not all device types are supported equally.
106 |     | -#
107 |     | -# Pointwise operations are memory-bound, for each operation PyTorch launches a
108 |     | -# separate kernel. Each kernel loads data from the memory, performs computation
109 |     | -# (this step is usually inexpensive) and stores results back into the memory.
110 |     | -#
111 |     | -# Fused operator launches only one kernel for multiple fused pointwise ops and
112 |     | -# loads/stores data only once to the memory. This makes JIT very useful for
113 |     | -# activation functions, optimizers, custom RNN cells etc.
    |  99 | +# Pointwise operations such as elementwise addition, multiplication, and math
    | 100 | +# functions like ``sin()``, ``cos()``, ``sigmoid()``, etc., can be combined into a
    | 101 | +# single kernel. This fusion helps reduce memory access and kernel launch times.
    | 102 | +# Typically, pointwise operations are memory-bound; PyTorch eager mode launches
    | 103 | +# a separate kernel for each operation, which involves loading data from memory,
    | 104 | +# executing the operation (often not the most time-consuming step), and writing
    | 105 | +# the results back to memory.
    | 106 | +#
    | 107 | +# By using a fused operator, only one kernel is launched for multiple pointwise
    | 108 | +# operations, and data is loaded and stored just once. This efficiency is
    | 109 | +# particularly beneficial for activation functions, optimizers, and custom RNN cells.
    | 110 | +#
    | 111 | +# PyTorch 2 introduces a compile mode backed by TorchInductor, an underlying compiler
    | 112 | +# that automatically fuses kernels. TorchInductor extends its capabilities beyond simple
    | 113 | +# element-wise operations, enabling advanced fusion of eligible pointwise and reduction
    | 114 | +# operations for optimized performance.
114 | 115 | #
115 | 116 | # In the simplest case fusion can be enabled by applying
116 |     | -# `torch.jit.script <https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_
    | 117 | +# `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_
117 | 118 | # decorator to the function definition, for example:
118 | 119 |
119 |     | -@torch.jit.script
120 |     | -def fused_gelu(x):
    | 120 | +@torch.compile
    | 121 | +def gelu(x):
121 | 122 |     return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
122 | 123 |
123 | 124 | ###############################################################################
124 | 125 | # Refer to
125 |     | -# `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
    | 126 | +# `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
126 | 127 | # for more advanced use cases.
127 | 128 |
128 | 129 | ###############################################################################
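
The compiled ``gelu`` above is called like any ordinary Python function. A minimal usage sketch (the input shape is illustrative; the first call triggers compilation, later calls reuse the compiled kernel):

    import torch

    @torch.compile
    def gelu(x):
        return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

    x = torch.randn(8, 1024)   # illustrative input
    y = gelu(x)                # first call compiles; subsequent calls hit the cached, fused kernel
    # 1.41421 approximates sqrt(2), so the result closely matches the built-in exact GELU
    print(torch.allclose(y, torch.nn.functional.gelu(x), atol=1e-4))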
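
TorchInductor's fusion is not limited to chains of pointwise ops; per the tutorial text, eligible pointwise and reduction operations can be fused as well. A short sketch of that pattern, assuming PyTorch 2.x with a supported backend; the helper name ``scaled_sigmoid_sum`` and the shapes are made up for illustration:

    import torch

    # Pointwise ops (sigmoid, multiply) feeding a reduction (sum) are candidates
    # for fusion into fewer kernels when compiled with torch.compile / TorchInductor.
    @torch.compile
    def scaled_sigmoid_sum(x, scale):
        return (torch.sigmoid(x) * scale).sum(dim=-1)

    out = scaled_sigmoid_sum(torch.randn(1024, 1024), 0.5)  # out has shape (1024,)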