Use .backward() with in-place grad mutations for the GA API #8768

Merged · 1 commit · Mar 5, 2025
20 changes: 18 additions & 2 deletions test/spmd/test_train_spmd_linear_model.py
@@ -74,10 +74,26 @@ def test_gradient_accumulation_matches(self):
     # Verify that the model losses are not zero, and that the runs match.
     assert all(loss != 0 for loss in baseline_grad_acc_losses)
     assert all(
-        torch.allclose(baseline_loss, checkpointing_loss, rtol=1e-4, atol=1e-8)
-        for baseline_loss, checkpointing_loss in zip(baseline_grad_acc_losses,
+        torch.allclose(baseline_loss, loop_grad_acc_loss, rtol=1e-4, atol=1e-8)
+        for baseline_loss, loop_grad_acc_loss in zip(baseline_grad_acc_losses,
                                                       loop_grad_acc_losses))
+
+    if not SKIP_GRADIENT_CHECKPOINTING:
+      print('Training loop with XLA\'s `While` gradient accumulation and '
+            'gradient checkpointing.')
+      with extended_argv(
+          COMMON_GRAD_ACC_ARGS +
+          ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
+        loop_grad_acc_grad_chkpt_losses = train_and_evaluate_grad_acc()
+      assert all(
+          torch.allclose(
+              baseline_loss,
+              loop_grad_acc_grad_chkpt_loss,
+              rtol=1e-4,
+              atol=1e-8)
+          for baseline_loss, loop_grad_acc_grad_chkpt_loss in zip(
+              baseline_grad_acc_losses, loop_grad_acc_grad_chkpt_losses))
 
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
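For orientation, here is a minimal sketch (not taken from the repository) of the kind of baseline the test compares against: presumably a plain Python gradient-accumulation loop that scales each micro-batch loss and lets .backward() accumulate into param.grad before a single optimizer step. The model, data, and step count are made up for illustration; the actual baseline is produced by train_and_evaluate_grad_acc under the COMMON_GRAD_ACC_ARGS flags.

```python
import torch

# Hypothetical baseline gradient-accumulation loop (illustrative only; the
# real baseline lives in the SPMD linear-model training script).
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
num_gradient_steps = 4
micro_batches = [(torch.randn(2, 8), torch.randn(2, 1))
                 for _ in range(num_gradient_steps)]

optimizer.zero_grad()
baseline_losses = []
for x, y in micro_batches:
  loss = torch.nn.functional.mse_loss(model(x), y) / num_gradient_steps
  loss.backward()                    # accumulates into param.grad in place
  baseline_losses.append(loss.detach())
optimizer.step()                     # single step over the accumulated grads
```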
16 changes: 12 additions & 4 deletions torch_xla/experimental/gradient_accumulation.py
@@ -181,6 +181,15 @@ def _prepare_fake_tensors(
   grads = [param.grad for param in params]
   body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
                     *fake_carried_tensors, *params, *grads)
+  # TODO - Fake the gradients once we are able to create placeholder tensors.
+  # Since the body is expected to do an in-place mutation of the gradients, we
+  # clone the gradients and use that as an input to the body. This will ensure
+  # that we retain a device data IR node in the graph. The cloned gradient will
+  # be updated to denote an IR operation (e.g. %add), and that can not be
+  # captured as a device data input for the other required computations, namely
+  # the condition and init for the XLA while loop.
+  for param in params:
+    param.grad = param.grad.clone()
   body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
                         tuple(fake_carried_tensors), tuple(params),
                         tuple(grads))
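To make the motivation for the clone concrete, here is a small plain-PyTorch sketch (my own illustration, not code from this PR): the original grad tensors are kept aside untouched, while the clones installed as param.grad absorb the in-place accumulation performed by .backward() inside the traced body. Under torch_xla, the untouched originals keep their device-data IR nodes for the while-loop condition and init computations.

```python
import torch

# Illustrative sketch of the cloning pattern (plain PyTorch, no XLA device).
param = torch.nn.Parameter(torch.ones(3))
param.grad = torch.zeros(3)

grads = [param.grad]                 # original handle, never mutated below
param.grad = param.grad.clone()      # the clone becomes the body's mutable grad

(param * 2.0).sum().backward()       # accumulates into param.grad (the clone)
print(grads[0])                      # tensor([0., 0., 0.]) -- original untouched
print(param.grad)                    # tensor([2., 2., 2.]) -- clone got the update
```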
@@ -375,10 +384,9 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
     else:
       loss, *carried_tensors = result
     loss /= context.num_gradient_steps
-    gradients = torch.autograd.grad(loss, model_parameters)
-    acc_grads = [prev_grad + grad for prev_grad, grad in zip(grads, gradients)]
-    return (iteri, loss, *iterable_tensors, *carried_tensors, *params,
-            *acc_grads)
+    loss.backward()
+    grads = [param.grad for param in params]
+    return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads)
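As a reminder of the semantics the new body relies on (a generic PyTorch property, not something specific to this PR): .backward() accumulates into param.grad in place across calls unless the gradients are zeroed, which is exactly the per-micro-batch accumulation the loop body needs.

```python
import torch

# .backward() accumulates gradients in place: two backward passes without
# zero_grad() sum their contributions into w.grad.
w = torch.ones(3, requires_grad=True)
(w * 2.0).sum().backward()
(w * 3.0).sum().backward()
print(w.grad)   # tensor([5., 5., 5.])  == 2 + 3 accumulated into the same tensor
```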
Review comment (Collaborator):
Why do we need the grads variable in body_fn if we are doing an in-place update? The params argument should have everything we need, right?

Reply (Collaborator, Author):
Indeed. When tracing the body, it will identify the XLA device data inputs for the gradients, so we need to make sure those are part of the output (T -> T) for both the body and the condition. I added them explicitly since that was clearer when working with torch.autograd.grad. Now that we have changed to .backward() with in-place updates, we should ideally leverage the hoisted variables to achieve this automatically. That can be out of scope for this PR (which unblocks gradient checkpointing), since we should eventually revamp/unify most of these internals with while_loop/fori_loop/scan as we better understand the existing LTC limitations and complications with the XLA while op. I can create a follow-up issue for this; thanks for raising it.
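For readers less familiar with the constraint mentioned above, here is a purely illustrative sketch (hypothetical names, a plain Python loop rather than the XLA while op) of the T -> T requirement: the body must return the same tuple structure it receives, so any tensor it updates, including the accumulated gradient, has to be threaded through the carried state.

```python
import torch

# Toy while-style loop with a T -> T body: the carried state that goes in is
# exactly the structure that comes out, so the accumulated gradient must be
# part of it even though it is only being updated.
def toy_cond(i, loss, grad):
  return i < 3

def toy_body(i, loss, grad):
  step_loss = torch.tensor(1.0)
  return i + 1, loss + step_loss, grad + step_loss  # same structure as inputs

state = (torch.tensor(0), torch.tensor(0.0), torch.zeros(()))
while toy_cond(*state):
  state = toy_body(*state)
print(state)  # (tensor(3), tensor(3.), tensor(3.))
```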


  if not torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config():
    warnings.warn(