
Commit 86d3fe1

First commit
1 parent ea4e155 commit 86d3fe1

2 files changed: +207 -0 lines changed


Diff for: recipes_source/foreach_map.py

+198
@@ -0,0 +1,198 @@
"""
(beta) Explicit horizontal fusion with foreach_map and torch.compile
======================================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

#########################################################
# Horizontal fusion is a key optimization in ML compilers. In eager mode,
# it is typically expressed using the ``torch._foreach*`` ops, which parallelize
# an operation across a list of tensors. However, supporting all possible permutations
# of arguments (for example, mixtures of scalars and lists) is quite difficult.
# ``foreach_map`` converts any pointwise op in ``torch`` into a horizontally fused
# foreach variant. In this tutorial, we will demonstrate how to implement the Adam
# optimizer with ``foreach_map`` to generate a fully fused kernel.
#
# .. note::
#
#    This tutorial requires PyTorch 2.7.0 or later.
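#########################################################
# To give a sense of the API before building the optimizer, here is a minimal
# sketch (not executed as part of this recipe, and assuming a CUDA device as in
# the rest of the tutorial) of fusing a single pointwise op across a list of
# tensors. The import path and call pattern mirror the ones used later in this
# recipe.
#
# .. code-block:: python
#
#     import torch
#     from torch._higher_order_ops.foreach_map import foreach_map
#
#     def fused_add(xs, ys):
#         # a single foreach_map call becomes one horizontally fused kernel
#         return foreach_map(torch.add, xs, ys)
#
#     compiled_fused_add = torch.compile(fused_add)
#     xs = [torch.rand(1024, device="cuda") for _ in range(3)]
#     ys = [torch.rand(1024, device="cuda") for _ in range(3)]
#     out = compiled_fused_add(xs, ys)  # list of 3 result tensors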

#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
# We instantiate an independent copy to compare the two optimizer implementations.
#
import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

# Create a simple model and an identical copy
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
model_copy = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
# give the copy the same weights so the two optimizer implementations can be compared
model_copy.load_state_dict(model.state_dict())
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)
output_copy = model_copy(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()
output_copy.sum().backward()

#####################################################################
# Helper functions for foreach_map implementation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll begin our implementation of the Adam optimizer.
#
from torch._higher_order_ops.foreach_map import foreach_map

# Helper function to extract optimizer states from a torch.optim.Adam instance
# (the optimizer must have taken at least one step so that its state is initialized)
def get_inputs(optim):
    steps = []
    params = []
    grads = []
    exp_avgs = []
    exp_avg_sqs = []
    for group in optim.param_groups:
        for p in group["params"]:
            params.append(p)
            grads.append(p.grad)
            state = optim.state[p]
            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            steps.append(state["step"])

    return steps, params, exp_avgs, exp_avg_sqs


# Functions to update the different optimizer states
def update_exp_avg_sq(exp_avg_sq, grad, beta2):
    return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)


def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    bias_correction1 = 1 - torch.pow(beta1, step)
    bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()
    step_size = (lr / bias_correction1).neg()
    denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
    return torch.add(param, torch.div(exp_avg, denom))

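######################################################################
# For reference, together with the ``torch.lerp`` update of ``exp_avg`` in the
# full implementation below, these helpers implement the standard Adam update
# (shown here without weight decay); ``update_param`` simply folds the bias
# corrections and the negative step size into the denominator so the update can
# be written as a single add:
#
# .. math::
#
#    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
#
#    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
#
#    \hat{m}_t = \frac{m_t}{1 - \beta_1^t} \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}
#
#    \theta_t = \theta_{t-1} - \mathrm{lr} \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}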

# Our full Adam implementation
def foreach_map_adam(
    steps,
    params,
    exp_avgs,
    exp_avg_sqs,
    weight_decay=0,
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    eps=1e-8,
):
    with torch.no_grad():
        grads = [param.grad for param in params]
        # update step
        updated_steps = foreach_map(lambda x: x + 1, steps)
        torch._foreach_copy_(steps, updated_steps)

        if weight_decay != 0:
            foreach_map(torch.add, (grads,), alpha=weight_decay)

        # Higher-order operators (HOPs) cannot have multiple outputs at the moment,
        # so we need to call foreach_map once for each output
        exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)
        exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)
        params_updated = foreach_map(
            update_param,
            params,
            steps,
            exp_avgs_updated,
            exp_avgs_sq_updated,
            beta1,
            beta2,
            lr,
            eps,
        )
        # Higher-order operators (HOPs) don't support input mutation today,
        # so we manually update the states in-place
        torch._foreach_copy_(exp_avgs, exp_avgs_updated)
        torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)
        torch._foreach_copy_(params, params_updated)
    return

#####################################################################
# Setting up and running the compiled kernel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll run our Adam optimizer
# and compare the results.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01))

# warm up the optimizer state dict (Adam creates its state lazily on the first step)
opt_eager.step()
opt_eager_copy.step()

inputs = get_inputs(opt_eager_copy)
compiled_adam = torch.compile(foreach_map_adam)

# optionally view the output code
torch._logging.set_logs(output_code=True)

# Warmup runs to compile the function
for _ in range(5):
    opt_eager.step()
    compiled_adam(*inputs)

# verify that the eager and compiled implementations produce matching parameters
for eager_p, compile_p in zip(
    opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]
):
    print(torch.allclose(eager_p, compile_p))

# Benchmark performance

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)
compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")

######################################################################
# Conclusion
# ~~~~~~~~~~
# In this tutorial, we implemented a custom, fully fused Adam optimizer using ``foreach_map``.
# By combining ``foreach_map`` with ``torch.compile``, we generated a single horizontally fused
# kernel for the optimizer update, and the same approach can be applied to any pointwise
# composition of ``torch`` ops to improve model performance through horizontal fusion.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro to the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.

Diff for: recipes_source/recipes_index.rst

+9
@@ -335,6 +335,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/compiling_optimizer_lr_scheduler.html
    :tags: Model-Optimization

+.. (beta) Explicit horizontal fusion with foreach_map and torch.compile
+
+.. customcarditem::
+   :header: (beta) Explicit horizontal fusion with foreach_map and torch.compile
+   :card_description: Horizontally fuse pointwise ops with torch.compile
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/foreach_map.py
+   :tags: Model-Optimization
+
 .. Using User-Defined Triton Kernels with ``torch.compile``

 .. customcarditem::