
Commit 5922c4d

Use .backward() with in-place grad mutations for the GA API
1 parent 1ab8216 commit 5922c4d


2 files changed: +30 -6 lines changed


test/spmd/test_train_spmd_linear_model.py

Lines changed: 18 additions & 2 deletions
@@ -74,10 +74,26 @@ def test_gradient_accumulation_matches(self):
     # Verify that the model losses are not zero, and that the runs match.
     assert all(loss != 0 for loss in baseline_grad_acc_losses)
     assert all(
-        torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
-        for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
+        torch.allclose(baseline_loss, loop_grad_acc_loss, rtol=1e-4, atol=1e-8)
+        for baseline_loss, loop_grad_acc_loss in zip(baseline_grad_acc_losses,
                                                      loop_grad_acc_losses))
 
+    if not SKIP_GRADIENT_CHECKPOINTING:
+      print('Training loop with XLA\'s `While` gradient accumulation and '
+            'gradient checkpointing.')
+      with extended_argv(
+          COMMON_GRAD_ACC_ARGS +
+          ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
+        loop_grad_acc_grad_chkpt_losses = train_and_evaluate_grad_acc()
+        assert all(
+            torch.allclose(
+                baseline_loss,
+                loop_grad_acc_grad_chkpt_loss,
+                rtol=1e-4,
+                atol=1e-8)
+            for baseline_loss, loop_grad_acc_grad_chkpt_loss in zip(
+                baseline_grad_acc_losses, loop_grad_acc_grad_chkpt_losses))
+
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
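A note on the assertion pattern in this test: torch.allclose compares each pair elementwise against |a - b| <= atol + rtol * |b|. A minimal standalone illustration of the same check, with made-up loss values (not taken from any test run):

import torch

baseline_losses = [torch.tensor(0.7312), torch.tensor(0.5128)]
loop_losses = [torch.tensor(0.73121), torch.tensor(0.51279)]

# Each pair differs by roughly 1e-5, well within rtol=1e-4 of the loss value.
assert all(
    torch.allclose(baseline_loss, loop_loss, rtol=1e-4, atol=1e-8)
    for baseline_loss, loop_loss in zip(baseline_losses, loop_losses))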

torch_xla/experimental/gradient_accumulation.py

Lines changed: 12 additions & 4 deletions
@@ -179,6 +179,15 @@ def _prepare_fake_tensors(
 
   body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
                     *fake_carried_tensors, *params, *grads)
+  # TODO - Fake the gradients once we are able to create placeholder tensors.
+  # Since the body is expected to do an in-place mutation of the gradients, we
+  # clone the gradients and use that as an input to the body. This will ensure
+  # that we retain a device data IR node in the graph. The cloned gradient will
+  # be updated to denote an IR operation (e.g. %add), and that can not be
+  # captured as a device data input for the other required computations, namely
+  # the condition and init for the XLA while loop.
+  for param in params:
+    param.grad = param.grad.clone()
   body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
                         tuple(fake_carried_tensors), tuple(params),
                         tuple(grads))
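The clone above isolates the in-place gradient mutation done inside the body from the original .grad tensors, which still need to serve as plain device-data inputs to the while-loop condition and init computations. A minimal eager-mode PyTorch sketch of that clone-then-mutate pattern (no XLA involved; the parameter below is hypothetical):

import torch

param = torch.nn.Parameter(torch.randn(3))
param.grad = torch.zeros_like(param)  # the original accumulated gradient
original_grad = param.grad            # keep a handle to the original tensor

param.grad = param.grad.clone()       # the backward pass now updates a distinct tensor
loss = (2 * param).sum()
loss.backward()                       # accumulates 2.0 per element into the clone

print(original_grad)  # still all zeros: untouched by the backward pass
print(param.grad)     # tensor([2., 2., 2.]): the clone absorbed the update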
@@ -373,10 +382,9 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
     else:
       loss, *carried_tensors = result
     loss /= context.num_gradient_steps
-    gradients = torch.autograd.grad(loss, model_parameters)
-    acc_grads = [prev_grad + grad for prev_grad, grad in zip(grads, gradients)]
-    return (iteri, loss, *iterable_tensors, *carried_tensors, *params,
-            *acc_grads)
+    loss.backward()
+    grads = [param.grad for param in params]
+    return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads)
 
   init_grads = []
   # Initialize the gradients to zero.
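For context on the hunk above: torch.autograd.grad returns freshly allocated gradient tensors that must be summed into the accumulators by hand, while loss.backward() accumulates directly into each param.grad. A standalone plain-PyTorch sketch (the tiny linear model and batches are made up) showing the two styles end up with the same accumulated gradients:

import torch

model = torch.nn.Linear(4, 1)
batches = [torch.randn(8, 4) for _ in range(3)]

# Previous style: functional gradients, accumulated manually.
acc_grads = [torch.zeros_like(p) for p in model.parameters()]
for x in batches:
  loss = model(x).sum() / len(batches)
  grads = torch.autograd.grad(loss, tuple(model.parameters()))
  acc_grads = [a + g for a, g in zip(acc_grads, grads)]

# New style: .backward() accumulates into .grad in place.
for p in model.parameters():
  p.grad = torch.zeros_like(p)
for x in batches:
  loss = model(x).sum() / len(batches)
  loss.backward()

for a, p in zip(acc_grads, model.parameters()):
  assert torch.allclose(a, p.grad)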
