# This file is copied from https://github.com/pytorch/pytorch/blob/master/torch/utils/checkpoint.py.
# PyTorch/XLA needs to add `optimization_barrier` before saving the input for the backward, hence we
# slightly modify the upstream version of the checkpoint util function.
import torch
import warnings
import torch_xla.core.xla_model as xm
from torch.utils.checkpoint import detach_variable, check_backward_validity, get_device_states, set_device_states
from typing import Any, Iterable, List, Tuple, Union


class CheckpointFunction(torch.autograd.Function):

  @staticmethod
  def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    ctx.gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    ctx.cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    if preserve_rng_state:
      ctx.fwd_cpu_state = torch.get_rng_state()
      # Don't eagerly initialize the cuda context by accident.
      # (If the user intends that the context is initialized later, within their
      # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
      # we have no way to anticipate this will happen before we run the function.)
      ctx.had_cuda_in_fwd = False
      if torch.cuda._initialized:
        ctx.had_cuda_in_fwd = True
        ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    tensor_outputs = []
    for i, arg in enumerate(args):
      if torch.is_tensor(arg):
        tensor_inputs.append(arg)
        ctx.tensor_indices.append(i)
        ctx.inputs.append(None)
      else:
        ctx.inputs.append(arg)

    with torch.no_grad():
      outputs = run_function(*args)
    if torch.is_tensor(outputs):
      tensor_outputs.append(outputs)
    # A torch.Tensor is itself Iterable, so check for it first to avoid
    # iterating through the tensor.
    elif isinstance(outputs, Iterable):
      for output in outputs:
        if torch.is_tensor(output):
          tensor_outputs.append(output)
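
    # XLA-specific: mark the tensors we are about to save (together with the
    # checkpointed outputs) with an optimization barrier. Without it, the XLA
    # compiler may fuse these tensors with the surrounding graph (e.g. via
    # common-subexpression elimination) and keep the original activations
    # alive, which would defeat the memory savings of checkpointing.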
    xm.optimization_barrier_(tensor_inputs + tensor_outputs)
    ctx.save_for_backward(*tensor_inputs)

    return outputs

  @staticmethod
  def backward(ctx, *args):
    if not torch.autograd._is_checkpoint_valid():
      raise RuntimeError(
          "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
          " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
          " argument.")
    # Copy the list to avoid modifying original list.
    inputs = list(ctx.inputs)
    tensor_indices = ctx.tensor_indices
    tensors = ctx.saved_tensors

    # Fill in inputs with appropriate saved tensors.
    for i, idx in enumerate(tensor_indices):
      inputs[idx] = tensors[i]

    # Stash the surrounding rng state, and mimic the state that was
    # present at this time during forward. Restore the surrounding state
    # when we're done.
    rng_devices = []
    if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
      rng_devices = ctx.fwd_gpu_devices
    with torch.random.fork_rng(
        devices=rng_devices, enabled=ctx.preserve_rng_state):
      if ctx.preserve_rng_state:
        torch.set_rng_state(ctx.fwd_cpu_state)
        if ctx.had_cuda_in_fwd:
          set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
      detached_inputs = detach_variable(tuple(inputs))
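      # Recompute the forward pass with grad enabled, restoring the autocast
      # state captured during the original forward so the recomputation runs
      # under the same settings.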
      with torch.enable_grad(), \
          torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
          torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
        outputs = ctx.run_function(*detached_inputs)

    if isinstance(outputs, torch.Tensor):
      outputs = (outputs,)

    # run backward() with only the tensors that require grad
    outputs_with_grad = []
    args_with_grad = []
    for i in range(len(outputs)):
      if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
        outputs_with_grad.append(outputs[i])
        args_with_grad.append(args[i])
    if len(outputs_with_grad) == 0:
      raise RuntimeError("none of the outputs has requires_grad=True;"
                         " this checkpoint() is not necessary")
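    # Backprop through the recomputed graph, using the incoming gradients
    # (args) as grad_outputs; this populates .grad on detached_inputs and on
    # any parameters used inside run_function.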
    torch.autograd.backward(outputs_with_grad, args_with_grad)
    grads = tuple(
        inp.grad if isinstance(inp, torch.Tensor) else None
        for inp in detached_inputs)

    return (None, None) + grads


def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
  r"""Checkpoint a model or part of the model

  Checkpointing works by trading compute for memory. Rather than storing all
  intermediate activations of the entire computation graph for computing
  backward, the checkpointed part does **not** save intermediate activations,
  and instead recomputes them in the backward pass. It can be applied on any
  part of a model.

  Specifically, in the forward pass, :attr:`function` will run in
  :func:`torch.no_grad` manner, i.e., not storing the intermediate
  activations. Instead, the forward pass saves the inputs tuple and the
  :attr:`function` parameter. In the backward pass, the saved inputs and
  :attr:`function` are retrieved, and the forward pass is computed on
  :attr:`function` again, now tracking the intermediate activations, and then
  the gradients are calculated using these activation values.

  The output of :attr:`function` can contain non-Tensor values and gradient
  recording is only performed for the Tensor values. Note that if the output
  consists of nested structures (ex: custom objects, lists, dicts etc.)
  consisting of Tensors, these Tensors nested in custom structures will not
  be considered as part of autograd.

  .. warning::
      If the :attr:`function` invocation during backward does anything
      different than the one during forward, e.g., due to some global
      variable, the checkpointed version won't be equivalent, and
      unfortunately it can't be detected.

  .. warning::
      If ``use_reentrant=True`` is specified, then if the checkpointed segment
      contains tensors detached from the computational graph by ``detach()``
      or ``torch.no_grad()``, the backward pass will raise an error. This is
      because ``checkpoint`` makes all the outputs require gradients, which
      causes issues when a tensor is defined to have no gradient in the model.
      To circumvent this, detach the tensors outside of the ``checkpoint``
      function. Note that the checkpointed segment can contain tensors
      detached from the computational graph if ``use_reentrant=False`` is
      specified.

  .. warning::
      If ``use_reentrant=True`` is specified, at least one of the inputs needs
      to have :code:`requires_grad=True` if grads are needed for model inputs,
      otherwise the checkpointed part of the model won't have gradients. At
      least one of the outputs needs to have :code:`requires_grad=True` as
      well. Note that this does not apply if ``use_reentrant=False`` is
      specified.

  .. warning::
      If ``use_reentrant=True`` is specified, checkpointing currently only
      supports :func:`torch.autograd.backward` and only if its ``inputs``
      argument is not passed. :func:`torch.autograd.grad`
      is not supported. If ``use_reentrant=False`` is specified, checkpointing
      will work with :func:`torch.autograd.grad`.

  Args:
      function: describes what to run in the forward pass of the model or
          part of the model. It should also know how to handle the inputs
          passed as the tuple. For example, in LSTM, if user passes
          ``(activation, hidden)``, :attr:`function` should correctly use the
          first input as ``activation`` and the second input as ``hidden``.
      preserve_rng_state(bool, optional, default=True): If ``False``, omit
          stashing and restoring the RNG state during each checkpoint.
      use_reentrant(bool, optional, default=True): Use checkpointing
          implementation that requires re-entrant autograd.
          If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
          implementation that does not require re-entrant autograd. This
          allows ``checkpoint`` to support additional functionality, such as
          working as expected with ``torch.autograd.grad``. Note that future
          versions of PyTorch will default to ``use_reentrant=False``.
      args: tuple containing inputs to the :attr:`function`

  Returns:
      Output of running :attr:`function` on :attr:`*args`
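
  Example (a minimal usage sketch; ``block`` and ``x`` below are illustrative
  placeholders, not part of this module)::

      import torch
      import torch.nn as nn
      import torch_xla.core.xla_model as xm

      device = xm.xla_device()
      block = nn.Sequential(nn.Linear(16, 16), nn.ReLU()).to(device)
      x = torch.randn(4, 16, device=device, requires_grad=True)

      # Activations inside `block` are recomputed during the backward pass
      # instead of being stored, trading an extra forward pass for memory.
      out = checkpoint(block, x)
      out.sum().backward()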
  """
  # Hack to mix *args with **kwargs in a python 2.7-compliant way
  preserve = kwargs.pop('preserve_rng_state', True)
  if kwargs:
    raise ValueError("Unexpected keyword arguments: " +
                     ",".join(arg for arg in kwargs))

  if use_reentrant:
    return CheckpointFunction.apply(function, preserve, *args)
  else:
    raise ValueError("XLA currently does not support use_reentrant==False")