Commit 16498c3
Add Checkpoint api to use optimization barrier (#3524)
* Add PyTorch/XLA version of checkpoint with optimization barrier
* Add checkpoint code
* update test
1 parent dbc6760 · commit 16498c3
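As context for the diff below, here is a minimal usage sketch of the new torch_xla.utils.checkpoint.checkpoint API. It is not part of the commit: the Sequential module, tensor shapes, and input are illustrative only, and it assumes an actual XLA device (per the run_tests.sh note, optimization_barrier does not yet work on CPU).

import torch
import torch_xla.core.xla_model as xm
import torch_xla.utils.checkpoint as checkpoint

device = xm.xla_device()
# Illustrative sub-module and input; any callable that takes and returns tensors works.
layer = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU()).to(device)
x = torch.randn(8, 128, device=device, requires_grad=True)

# The forward pass runs under no_grad; the inputs and outputs are passed through
# xm.optimization_barrier_ before being saved, and the activations inside `layer`
# are recomputed during the backward pass.
y = checkpoint.checkpoint(layer, x)
y.sum().backward()
xm.mark_step()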

File tree

3 files changed (+248, −0 lines)


test/run_tests.sh (+3 lines)

@@ -93,6 +93,9 @@ function run_all_tests {
   run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_eager_debug python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   run_async_rng python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
+  # TODO: enable this test after tf update, currently optimization_barrier does not
+  # work on CPU.
+  # run_test python3 "$CDIR/test_checkpoint.py"
   run_test python3 "$CDIR/test_mp_replication.py"
   run_test python3 "$CDIR/test_mp_all_to_all.py"
   run_test python3 "$CDIR/test_mp_collective_permute.py"

test/test_checkpoint.py (new file, +36 lines)

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla
import torch_xla.utils.checkpoint as checkpoint


def run():
  device = xm.xla_device()
  model = torch.nn.ModuleList([
      torch.nn.Sequential(
          torch.nn.Conv2d(1024, 1024, 1),
          torch.nn.ReLU(),
          torch.nn.Conv2d(1024, 1024, 1),
          torch.nn.ReLU(),
      ) for _ in range(2)
  ]).to(device)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

  for step in range(20):
    dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
    optimizer.zero_grad()
    x = dummy_data
    for n_l, layer in enumerate(model):
      if n_l > 0:
        x = checkpoint.checkpoint(layer, x)
      else:
        x = layer(x)
    dummy_loss = x.sum()
    dummy_loss.backward()
    optimizer.step()
    xm.mark_step()


if __name__ == "__main__":
  run()

torch_xla/utils/checkpoint.py (new file, +209 lines)

# This file is copied from https://github.com/pytorch/pytorch/blob/master/torch/utils/checkpoint.py.
# PyTorch/XLA needs to add `optimization_barrier` before saving the input for the backward hence we
# slightly modify the upstream version of the checkpoint util function.
import torch
import warnings
import torch_xla.core.xla_model as xm
from torch.utils.checkpoint import detach_variable, check_backward_validity, get_device_states, set_device_states
from typing import Any, Iterable, List, Tuple, Union


class CheckpointFunction(torch.autograd.Function):

  @staticmethod
  def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    ctx.gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    ctx.cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    if preserve_rng_state:
      ctx.fwd_cpu_state = torch.get_rng_state()
      # Don't eagerly initialize the cuda context by accident.
      # (If the user intends that the context is initialized later, within their
      # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
      # we have no way to anticipate this will happen before we run the function.)
      ctx.had_cuda_in_fwd = False
      if torch.cuda._initialized:
        ctx.had_cuda_in_fwd = True
        ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    tensor_outputs = []
    for i, arg in enumerate(args):
      if torch.is_tensor(arg):
        tensor_inputs.append(arg)
        ctx.tensor_indices.append(i)
        ctx.inputs.append(None)
      else:
        ctx.inputs.append(arg)

    with torch.no_grad():
      outputs = run_function(*args)
      if torch.is_tensor(outputs):
        tensor_outputs.append(outputs)
      # tensor is Iterable so we need to avoid iterating through tensor
      elif isinstance(outputs, Iterable):
        for output in outputs:
          if torch.is_tensor(output):
            tensor_outputs.append(output)

    xm.optimization_barrier_(tensor_inputs + tensor_outputs)
    ctx.save_for_backward(*tensor_inputs)

    return outputs

  @staticmethod
  def backward(ctx, *args):
    if not torch.autograd._is_checkpoint_valid():
      raise RuntimeError(
          "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
          " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
          " argument.")
    # Copy the list to avoid modifying original list.
    inputs = list(ctx.inputs)
    tensor_indices = ctx.tensor_indices
    tensors = ctx.saved_tensors

    # Fill in inputs with appropriate saved tensors.
    for i, idx in enumerate(tensor_indices):
      inputs[idx] = tensors[i]

    # Stash the surrounding rng state, and mimic the state that was
    # present at this time during forward. Restore the surrounding state
    # when we're done.
    rng_devices = []
    if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
      rng_devices = ctx.fwd_gpu_devices
    with torch.random.fork_rng(
        devices=rng_devices, enabled=ctx.preserve_rng_state):
      if ctx.preserve_rng_state:
        torch.set_rng_state(ctx.fwd_cpu_state)
        if ctx.had_cuda_in_fwd:
          set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
      detached_inputs = detach_variable(tuple(inputs))
      with torch.enable_grad(), \
          torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
          torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
        outputs = ctx.run_function(*detached_inputs)

    if isinstance(outputs, torch.Tensor):
      outputs = (outputs,)

    # run backward() with only tensor that requires grad
    outputs_with_grad = []
    args_with_grad = []
    for i in range(len(outputs)):
      if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
        outputs_with_grad.append(outputs[i])
        args_with_grad.append(args[i])
    if len(outputs_with_grad) == 0:
      raise RuntimeError("none of output has requires_grad=True,"
                         " this checkpoint() is not necessary")
    torch.autograd.backward(outputs_with_grad, args_with_grad)
    grads = tuple(
        inp.grad if isinstance(inp, torch.Tensor) else None
        for inp in detached_inputs)

    return (None, None) + grads


def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
  r"""Checkpoint a model or part of the model

  Checkpointing works by trading compute for memory. Rather than storing all
  intermediate activations of the entire computation graph for computing
  backward, the checkpointed part does **not** save intermediate activations,
  and instead recomputes them in backward pass. It can be applied on any part
  of a model.

  Specifically, in the forward pass, :attr:`function` will run in
  :func:`torch.no_grad` manner, i.e., not storing the intermediate
  activations. Instead, the forward pass saves the inputs tuple and the
  :attr:`function` parameter. In the backwards pass, the saved inputs and
  :attr:`function` is retrieved, and the forward pass is computed on
  :attr:`function` again, now tracking the intermediate activations, and then
  the gradients are calculated using these activation values.

  The output of :attr:`function` can contain non-Tensor values and gradient
  recording is only performed for the Tensor values. Note that if the output
  consists of nested structures (ex: custom objects, lists, dicts etc.)
  consisting of Tensors, these Tensors nested in custom structures will not
  be considered as part of autograd.

  .. warning::
    If :attr:`function` invocation during backward does anything different
    than the one during forward, e.g., due to some global variable, the
    checkpointed version won't be equivalent, and unfortunately it can't be
    detected.

  .. warning::
    If ``use_reentrant=True`` is specified, then if the checkpointed segment
    contains tensors detached from the computational graph by `detach()` or
    `torch.no_grad()`, the backward pass will raise an error. This is
    because `checkpoint` makes all the outputs require gradients which
    causes issues when a tensor is defined to have no gradient in the model.
    To circumvent this, detach the tensors outside of the `checkpoint`
    function. Note that the checkpointed segment can contain tensors
    detached from the computational graph if ``use_reentrant=False`` is
    specified.

  .. warning::
    If ``use_reentrant=True`` is specified, at least one of the inputs needs
    to have :code:`requires_grad=True` if grads are needed for model inputs,
    otherwise the checkpointed part of the model won't have gradients. At
    least one of the outputs needs to have :code:`requires_grad=True` as
    well. Note that this does not apply if ``use_reentrant=False`` is
    specified.

  .. warning::
    If ``use_reentrant=True`` is specified, checkpointing currently only
    supports :func:`torch.autograd.backward` and only if its `inputs`
    argument is not passed. :func:`torch.autograd.grad`
    is not supported. If ``use_reentrant=False`` is specified, checkpointing
    will work with :func:`torch.autograd.grad`.

  Args:
    function: describes what to run in the forward pass of the model or
      part of the model. It should also know how to handle the inputs
      passed as the tuple. For example, in LSTM, if user passes
      ``(activation, hidden)``, :attr:`function` should correctly use the
      first input as ``activation`` and the second input as ``hidden``
    preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
      the RNG state during each checkpoint.
    use_reentrant(bool, optional, default=True): Use checkpointing
      implementation that requires re-entrant autograd.
      If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
      implementation that does not require re-entrant autograd. This
      allows ``checkpoint`` to support additional functionality, such as
      working as expected with ``torch.autograd.grad``. Note that future
      versions of PyTorch will default to ``use_reentrant=False``.
    args: tuple containing inputs to the :attr:`function`

  Returns:
    Output of running :attr:`function` on :attr:`*args`
  """
  # Hack to mix *args with **kwargs in a python 2.7-compliant way
  preserve = kwargs.pop('preserve_rng_state', True)
  if kwargs:
    raise ValueError("Unexpected keyword arguments: " +
                     ",".join(arg for arg in kwargs))

  if use_reentrant:
    return CheckpointFunction.apply(function, preserve, *args)
  else:
    raise ValueError("XLA currently does not support use_reentrant==False")

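The only functional change relative to the upstream torch.utils.checkpoint is the xm.optimization_barrier_ call on the forward inputs and outputs before save_for_backward. A minimal sketch of that primitive on its own is shown below; the tensors and arithmetic are illustrative only, and it again assumes an actual XLA device.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
inputs = [torch.randn(4, 4, device=device)]
outputs = [inputs[0] @ inputs[0]]

# In-place call: applies an XLA optimization barrier to the listed tensors,
# preventing the compiler from optimizing computations across the barrier
# (which could otherwise undo the memory savings of recompute-in-backward).
xm.optimization_barrier_(inputs + outputs)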
0 commit comments