Use .backward() with in-place grad mutations for the GA API #8768


Merged
1 commit merged into pytorch:master on Mar 5, 2025

Conversation

rpsilva-aws
Collaborator

In this PR, we switch to using .backward() instead of torch.autograd.grad due to #8729. We resolve this similarly to scan (https://github.com/rpsilva-aws/xla/blob/master/torch_xla/experimental/scan.py#L232), except that in our case, since we have full control over the backward pass, we can explicitly clone the input gradients, knowing that the underlying tensors will be updated in-place.
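
For context, a minimal standalone contrast of the two approaches in plain PyTorch (illustrative only, not the GA API itself):

import torch

x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum()

# torch.autograd.grad returns fresh gradient tensors and leaves x.grad untouched.
(g,) = torch.autograd.grad(loss, (x,), retain_graph=True)
assert x.grad is None

# .backward() instead accumulates into x.grad in place, which is the behavior this PR relies on.
loss.backward()
assert torch.allclose(g, x.grad)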

In short, my understanding of why we need the clone is this: once we perform an in-place mutation, the resulting IR node is no longer device data. Since we rely on the lowering context to collect all parameters in the mapping (device_parameter_id_tensor_mapping), that node is left out, because it is now an IR op (e.g. %add), and we therefore cannot capture it as a parameter to the other XLA while computations (condition and init). This is a known issue shared with the scan API, and we should eventually find a more robust way around it. If we did, we could offer a more general, all-purpose while_loop/fori_loop API for users (and use it as part of scan and this gradient accumulation API).
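
A rough, hypothetical sketch of the cloning pattern described above (sketch_body and its arguments are illustrative names, not the actual GA API code):

import torch

def sketch_body(loss, params, acc_grads):
    # Seed .grad with a clone of the carried accumulated gradients, so the
    # tensors the lowering context captures remain plain device data.
    for param, acc in zip(params, acc_grads):
        param.grad = acc.clone()
    # .backward() then accumulates into param.grad in place; without the clone
    # above, the while-loop parameter would become an in-place IR op (e.g. %add)
    # that device_parameter_id_tensor_mapping cannot pick up.
    loss.backward()
    grads = [param.grad for param in params]
    return loss, grads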

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_chkpt branch from c49a1fd to 5922c4d on February 27, 2025 21:26
@rpsilva-aws
Collaborator Author

Canceled the workflow; this is pending #8758.

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_grad_acc_chkpt branch from 5922c4d to ecc92bd on February 28, 2025 00:49
@rpsilva-aws rpsilva-aws marked this pull request as ready for review February 28, 2025 00:50
*acc_grads)
# The backward pass accumulates the gradients into param.grad in place.
loss.backward()
grads = [param.grad for param in params]
return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads)
Collaborator


Why do we need the grads variable in body_fn if we are doing an in-place update? The params argument should have everything we need, right?

Collaborator Author


Indeed. When tracing the body, it will identify the XLA device data inputs for the gradients, so we need to make sure those are also part of the output (T -> T) for both the body and the condition. I added them explicitly because that was clearer when working with torch.autograd.grad. Now that we have switched to .backward() and in-place updates, we should ideally leverage the hoisted vars to achieve this automatically. That can be out of scope for this PR (which unblocks checkpointing), since we should eventually revamp/unify most of the internals with while_loop/fori_loop/scan as we better understand the existing LTC limitations/complications with the XLA while op. I can create a follow-up issue for this, thanks for raising it.
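
Roughly, what I mean by T -> T, as a hypothetical illustration (not the actual implementation):

def cond_fn(iteri, loss, *carried):
    # The condition consumes the same tuple of tensors that the body produces.
    return iteri > 0

def body_fn(iteri, loss, *carried):
    # ... trace the step, call loss.backward(), gather the grads ...
    # Every tensor the tracer identifies as a device-data input (including the
    # gradients) must reappear in the output tuple, so that the XLA while
    # computation's parameter and result signatures match.
    return (iteri - 1, loss, *carried)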

@rpsilva-aws rpsilva-aws requested a review from bhavya01 March 3, 2025 18:09
@rpsilva-aws
Collaborator Author

@bhavya01 Can you PTAL once you find the time? Thanks

@rpsilva-aws
Collaborator Author

rpsilva-aws commented Mar 5, 2025

Kind reminder @tengyifei @bhavya01, so that I have time to land a couple of follow-ups in time for 2.7.0. One of them is to leverage @tengyifei's #8785, so the clone would simultaneously serve as just the staging step (separating the IR that is fed to the body ctx from the one that is used in the mapping).

@bhavya01 bhavya01 merged commit 4540d81 into pytorch:master Mar 5, 2025
23 checks passed
@rpsilva-aws rpsilva-aws deleted the rpsilva_grad_acc_chkpt branch March 5, 2025 20:26
pgmoka added a commit that referenced this pull request Mar 5, 2025
Use .backward() with in-place grad mutations for the GA API (#8768)

Use placeholder tensor in scan (#8785)

Pin update to 20250303 (#8788)

Co-authored-by: Chengji Yao <[email protected]>

correct linter