
Commit ecc92bd

Use .backward() with in-place grad mutations for the GA API
1 parent 2e4f073 commit ecc92bd
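
The change swaps the while-loop body from functional `torch.autograd.grad` plus manual gradient summation to `loss.backward()`, which accumulates into each parameter's `.grad` buffer in place. Below is a minimal plain-PyTorch sketch of the two accumulation styles and their equivalence; it is not the torch_xla loop itself, and the toy model, batch shapes, and step count are illustrative assumptions only.

# Minimal plain-PyTorch sketch (not the torch_xla while-loop body) contrasting
# the two gradient-accumulation styles this commit switches between.
import torch

model = torch.nn.Linear(4, 2)
params = list(model.parameters())
micro_batches = [torch.randn(8, 4) for _ in range(3)]

# Old style: functional autograd plus manual accumulation into a running list.
acc_grads = [torch.zeros_like(p) for p in params]
for x in micro_batches:
  loss = model(x).sum() / len(micro_batches)
  gradients = torch.autograd.grad(loss, params)
  acc_grads = [prev + g for prev, g in zip(acc_grads, gradients)]

# New style: .backward() mutates each param.grad in place, so the running sum
# lives directly in the parameters' .grad buffers.
for p in params:
  p.grad = torch.zeros_like(p)
for x in micro_batches:
  loss = model(x).sum() / len(micro_batches)
  loss.backward()

# Both styles produce the same accumulated gradients.
for manual, in_place in zip(acc_grads, [p.grad for p in params]):
  assert torch.allclose(manual, in_place)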

2 files changed (+30 -6 lines)

test/spmd/test_train_spmd_linear_model.py

Lines changed: 18 additions & 2 deletions
@@ -74,10 +74,26 @@ def test_gradient_accumulation_matches(self):
     # Verify that the model losses are not zero, and that the runs match.
     assert all(loss != 0 for loss in baseline_grad_acc_losses)
     assert all(
-        torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
-        for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
+        torch.allclose(baseline_loss, loop_grad_acc_loss, rtol=1e-4, atol=1e-8)
+        for baseline_loss, loop_grad_acc_loss in zip(baseline_grad_acc_losses,
                                                      loop_grad_acc_losses))
 
+    if not SKIP_GRADIENT_CHECKPOINTING:
+      print('Training loop with XLA\'s `While` gradient accumulation and '
+            'gradient checkpointing.')
+      with extended_argv(
+          COMMON_GRAD_ACC_ARGS +
+          ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
+        loop_grad_acc_grad_chkpt_losses = train_and_evaluate_grad_acc()
+        assert all(
+            torch.allclose(
+                baseline_loss,
+                loop_grad_acc_grad_chkpt_loss,
+                rtol=1e-4,
+                atol=1e-8)
+            for baseline_loss, loop_grad_acc_grad_chkpt_loss in zip(
+                baseline_grad_acc_losses, loop_grad_acc_grad_chkpt_losses))
+
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
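
The new test branch reuses the `extended_argv` helper to rerun the training entry point under extra command-line flags. The helper is defined elsewhere in the repo's test utilities; what follows is a hypothetical sketch of such a context manager, included only to clarify the pattern, and may differ from the actual implementation.

# Hypothetical sketch of an `extended_argv`-style helper (the real one lives in
# the repo's test utilities): temporarily append extra CLI flags to sys.argv so
# argparse-based training code picks them up, then restore the original argv.
import sys
from contextlib import contextmanager


@contextmanager
def extended_argv(extra_args):
  original = sys.argv[:]
  sys.argv = sys.argv + list(extra_args)
  try:
    yield
  finally:
    sys.argv = original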

torch_xla/experimental/gradient_accumulation.py

Lines changed: 12 additions & 4 deletions
@@ -181,6 +181,15 @@ def _prepare_fake_tensors(
   grads = [param.grad for param in params]
   body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
                     *fake_carried_tensors, *params, *grads)
+  # TODO - Fake the gradients once we are able to create placeholder tensors.
+  # Since the body is expected to do an in-place mutation of the gradients, we
+  # clone the gradients and use that as an input to the body. This will ensure
+  # that we retain a device data IR node in the graph. The cloned gradient will
+  # be updated to denote an IR operation (e.g. %add), and that can not be
+  # captured as a device data input for the other required computations, namely
+  # the condition and init for the XLA while loop.
+  for param in params:
+    param.grad = param.grad.clone()
   body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
                         tuple(fake_carried_tensors), tuple(params),
                         tuple(grads))
@@ -375,10 +384,9 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
     else:
       loss, *carried_tensors = result
     loss /= context.num_gradient_steps
-    gradients = torch.autograd.grad(loss, model_parameters)
-    acc_grads = [prev_grad + grad for prev_grad, grad in zip(grads, gradients)]
-    return (iteri, loss, *iterable_tensors, *carried_tensors, *params,
-            *acc_grads)
+    loss.backward()
+    grads = [param.grad for param in params]
+    return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads)
 
   if not torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config():
     warnings.warn(
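
The two hunks work as a pair: cloning keeps the original `.grad` handles as plain device data that the loop's cond/init computations can still capture, while the traced body mutates the clones in place via `loss.backward()` and simply returns them. Below is a minimal sketch of that clone-then-mutate pattern on a single XLA device; the toy model, shapes, and step count are illustrative assumptions, and the torch_xla tracing machinery around the real body_fn is deliberately simplified away.

# Sketch (toy model; torch_xla internals simplified away) of the clone-then-
# mutate pattern: the cloned grads stand in for the traced body's inputs and are
# mutated in place by loss.backward(); the original handles stay untouched.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)
params = list(model.parameters())
for p in params:
  p.grad = torch.zeros_like(p)
xm.mark_step()  # materialize the zeroed grads as device data

original_grads = [p.grad for p in params]  # handles the cond/init can capture
for p in params:
  p.grad = p.grad.clone()  # the "body" will mutate these clones in place

# Stand-in for the traced body: backward() adds this step's gradients into the
# cloned buffers instead of producing new tensors via torch.autograd.grad.
loss = model(torch.randn(8, 4, device=device)).sum()
loss.backward()

grads_after_body = [p.grad for p in params]  # what the real body_fn now returns
xm.mark_step()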
