"""
(beta) Explicit horizontal fusion with foreach_map and torch.compile
======================================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

#########################################################
# Horizontal fusion is a key optimization in ML compilers. In eager mode, this is
# typically expressed with the ``torch._foreach*`` ops, which parallelize an
# operation across a list of tensors. However, supporting every permutation of
# arguments (for example, mixtures of scalars and lists) with dedicated kernels is
# difficult. ``foreach_map`` lets you convert any pointwise op in ``torch`` into a
# horizontally fused foreach variant. In this tutorial, we will use ``foreach_map``
# to implement the Adam optimizer and generate a fully fused kernel.
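#
# For instance, ``foreach_map`` can fuse a user-defined pointwise function that has
# no hand-written ``torch._foreach_*`` counterpart. Here is an illustrative snippet,
# where ``tensor_list`` stands for any list of CUDA tensors:
#
# .. code-block:: python
#
#     # applies 2 * x + 1 to every tensor in the list as one fused foreach op
#     scaled = foreach_map(lambda x: 2 * x + 1, tensor_list)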
#
#
# .. note::
#
#    This tutorial requires PyTorch 2.7.0 or later.

#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
# We create a second copy of the model with identical weights so that the two
# optimizer implementations can be compared starting from the same parameters.
#
import torch

# Exit cleanly if we are on a device that doesn't support ``torch.compile``
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

# Create a simple model and an identical copy of it
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
model_copy = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
# Copy the weights so that both models start from the same initialization;
# otherwise the parameter comparison at the end of the tutorial is meaningless.
model_copy.load_state_dict(model.state_dict())
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)
output_copy = model_copy(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()
output_copy.sum().backward()
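
# Note: we only run backward once, so the gradients stay fixed for every optimizer
# step below. This keeps the eager and compiled optimizer trajectories directly
# comparable.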

#####################################################################
# Helper functions for foreach_map implementation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll begin our implementation of the Adam optimizer.
#
from torch._higher_order_ops.foreach_map import foreach_map

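# As a quick, self-contained illustration: ``foreach_map`` maps a pointwise op across
# lists of tensors (and also accepts scalar arguments), and under ``torch.compile``
# the per-tensor ops should be horizontally fused into a single kernel.
xs = [torch.ones(4, device="cuda") for _ in range(3)]
ys = [torch.zeros(4, device="cuda") for _ in range(3)]
fused_lerp = torch.compile(lambda a, b: foreach_map(torch.lerp, a, b, 0.5))
print(fused_lerp(xs, ys))  # three tensors filled with 0.5
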
# Helper function to extract optimizer states from a torch.optim.Adam instance
def get_inputs(optim):
    steps = []
    params = []
    exp_avgs = []
    exp_avg_sqs = []
    for group in optim.param_groups:
        for p in group["params"]:
            params.append(p)
            state = optim.state[p]
            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            steps.append(state["step"])

    return steps, params, exp_avgs, exp_avg_sqs


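# ``torch.optim.Adam`` initializes its per-parameter state (``step``, ``exp_avg``,
# ``exp_avg_sq``) lazily, on the first call to ``step()``. That is why we take one
# warm-up eager step below before calling ``get_inputs``.
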
# Functions to update the different optimizer states
def update_exp_avg_sq(exp_avg_sq, grad, beta2):
    return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)


def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    bias_correction1 = 1 - torch.pow(beta1, step)
    bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()
    step_size = (lr / bias_correction1).neg()
    denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
    return torch.add(param, torch.div(exp_avg, denom))

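# ``update_param`` is an algebraic rewrite of the textbook Adam update:
#
#   param <- param - (lr / bias_correction1) * exp_avg
#                    / (exp_avg_sq.sqrt() / bias_correction2 + eps)
#
# Folding the subtraction into ``step_size = -(lr / bias_correction1)`` and dividing
# both denominator terms by ``step_size`` turns the whole step into a single
# ``add``/``div`` pair, which keeps the fused kernel simple.
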
# Our full Adam implementation
def foreach_map_adam(
    steps,
    params,
    exp_avgs,
    exp_avg_sqs,
    weight_decay=0,
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    eps=1e-8,
):
    with torch.no_grad():
        grads = [param.grad for param in params]
        # update the step count
        updated_steps = foreach_map(lambda x: x + 1, steps)
        torch._foreach_copy_(steps, updated_steps)

        if weight_decay != 0:
            # L2-style weight decay: grad <- grad + weight_decay * param
            grads = foreach_map(torch.add, grads, params, alpha=weight_decay)

        # Higher-order operators (HOPs) cannot have multiple outputs at the moment,
        # so we need to call foreach_map once for each output
        exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)
        exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)
        params_updated = foreach_map(
            update_param,
            params,
            steps,
            exp_avgs_updated,
            exp_avgs_sq_updated,
            beta1,
            beta2,
            lr,
            eps,
        )
        # Higher-order operators (HOPs) don't support input mutation today,
        # so we manually update the optimizer states in-place
        torch._foreach_copy_(exp_avgs, exp_avgs_updated)
        torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)
        torch._foreach_copy_(params, params_updated)
        return

#####################################################################
# Setting up and running the compiled kernel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll run our foreach_map Adam optimizer against the eager
# ``torch.optim.Adam`` implementation and compare the results.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
# Use the same hyperparameters as the ``foreach_map_adam`` defaults so that the
# eager and compiled results can be compared directly.
opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(1e-3))
opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(1e-3))

# take one step to initialize the optimizer state dict
opt_eager.step()
opt_eager_copy.step()

inputs = get_inputs(opt_eager_copy)
compiled_adam = torch.compile(foreach_map_adam)

# optionally view the output code
torch._logging.set_logs(output_code=True)
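# With ``output_code`` logging enabled, the first call to ``compiled_adam`` below
# prints the code Inductor generates; the optimizer math should show up as a small
# number of horizontally fused kernels rather than one kernel per parameter.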

# Warmup runs to compile the function; we step the eager optimizer in lockstep so
# that both models see the same number of updates
for _ in range(5):
    opt_eager.step()
    compiled_adam(*inputs)

# Verify that the eager and foreach_map implementations agree
for eager_p, compile_p in zip(
    opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]
):
    assert torch.allclose(eager_p, compile_p)

# Benchmark performance

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)
compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime:.2f}us")
print(f"compiled runtime: {compiled_runtime:.2f}us")


######################################################################
# Conclusion
# ~~~~~~~~~~
# In this tutorial, we implemented a custom fully-fused Adam optimizer with ``foreach_map``
# and ``torch.compile``. Writing the update as pointwise functions and mapping them over the
# parameter lists lets the compiler horizontally fuse the per-parameter work, which is why the
# compiled optimizer outperforms the eager implementation above. The same approach applies to
# any pointwise computation you want to fuse across a list of tensors.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro into the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.