Use .backward() with in-place grad mutations for the GA API #8768
Conversation
Force-pushed from c49a1fd to 5922c4d
Canceled the workflow, this is pending on #8758.
Force-pushed from 5922c4d to ecc92bd
        *acc_grads)
    loss.backward()
    grads = [param.grad for param in params]
    return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads)
Why do we need the grads variable in body_fn if we are doing an in-place update? The params argument should have everything we need, right?
Indeed. When tracing the body, the lowering context identifies the XLA device data inputs for the gradients, so we need to make sure those are part of the output (T -> T) for both the body and the condition. I added them explicitly since that was clearer when working with torch.autograd.grad. Now that we have switched to .backward() with in-place accumulation, we should ideally leverage the hoisted vars to achieve this automatically. That can stay out of scope for this PR (which unblocks checkpointing), since we should eventually revamp/unify most of these internals with while_loop/fori_loop/scan as we better understand the existing LTC limitations and complications with the XLA while op. I can create a follow-up issue for this, thanks for raising it.
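For illustration only, here is a minimal sketch of the T -> T constraint discussed above, written in plain PyTorch with hypothetical names (body_fn, params, iteri); it is not the actual gradient accumulation internals. Every carried tensor the body reads, including the gradient buffers, is returned again so the body's inputs and outputs line up one-to-one.

```python
import torch

# Hypothetical standalone example: two parameters with pre-allocated
# gradient buffers that the loop body accumulates into.
params = [torch.nn.Parameter(torch.randn(3)) for _ in range(2)]
for p in params:
    p.grad = torch.zeros_like(p)

def body_fn(iteri, x, *flat):
    n = len(params)
    ps = list(flat[:n])  # carried parameters; flat[n:] are the carried grads
    loss = sum((p * x).sum() for p in ps)
    loss.backward()  # accumulates into p.grad in place
    # Re-emit every carried tensor, including the gradients, so the body's
    # inputs and outputs keep the same structure (T -> T).
    return (iteri - 1, x, *ps, *[p.grad for p in ps])

state = (torch.tensor(3), torch.randn(3), *params, *[p.grad for p in params])
while state[0] > 0:
    state = body_fn(*state)
```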
@bhavya01 Can you PTAL once you find the time? Thanks
Kind reminder @tengyifei @bhavya01, so I have time to include a couple of follow-ups in time for 2.7.0. One of them is to leverage @tengyifei's #8785, so the clone would serve purely for staging (separating the IR that is fed to the body context from the one used in the mapping).
Use .backward() with in-place grad mutations for the GA API (#8768)
Use placeholder tensor in scan (#8785)
Pin update to 20250303 (#8788)
Co-authored-by: Chengji Yao <[email protected]>
correct linter
In this PR, we switch to using .backward() instead of torch.autograd.grad due to #8729. We resolve this similarly to scan (https://github.com/rpsilva-aws/xla/blob/master/torch_xla/experimental/scan.py#L232), except that in our case, since we have full control over the backward pass, we can explicitly clone the input gradients, knowing that the underlying tensors will be updated in place.

In short, my understanding of why we need the clone is that once we perform an in-place mutation, the resulting IR node is no longer device data. Since we rely on the lowering context to collect all parameters in the mapping (device_parameter_id_tensor_mapping), that IR node will be excluded, because it is now an IR op (e.g. %add). This prevents us from capturing it as a parameter to the other XLA while computations (condition and init). This is a known issue with both this and the scan API, and we should eventually find a more robust way around it. If we did, we could offer a more general, all-purpose while_loop/fori_loop API for users (and use it as part of scan and this gradient accumulation API).
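As a rough illustration of the clone-before-mutate pattern described above (the helper name accumulate_step and its structure are hypothetical, not the actual GA implementation; plain CPU tensors are used, no XLA device required):

```python
import torch

def accumulate_step(params, loss):
    # Snapshot the incoming gradient buffers before the in-place mutation
    # below; in the XLA setting this is the point where the gradients are
    # still plain device data.
    staged_grads = [
        p.grad.clone() if p.grad is not None else torch.zeros_like(p)
        for p in params
    ]
    # In-place accumulation into param.grad; after this, the live IR node
    # for each gradient is an op (e.g. %add) rather than device data, which
    # is why the snapshot above is taken first.
    loss.backward()
    return staged_grads

# Usage sketch.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
loss = sum((p ** 2).sum() for p in params)
snapshots = accumulate_step(params, loss)
```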