Add Checkpoint api to use optimization barrier #3524
Conversation
This is great, thanks for adding it!
Force-pushed from f16c243 to 9274644.
I think this PR is ready to land.

@ronghanghu We reverted the TF change since it broke the upstream CI build. I think we have one wheel that is built with the new TF to resolve the speed regression and the weird shape mismatch. Even if I merge this PR now, I don't recommend you update to that wheel until we reland the TF change.
I manually verified that this works on TPU and saw the expected memory saving, so I will merge this PR for now. @miladm feel free to take a pass; I can address the review comments in a follow-up PR. I also need to write better documentation around this part of the code.
Thanks for merging this PR!
I see. I'd be happy to try out the new wheel after the TF version PR #3523 is relanded.
I tried testing out this API, and somehow I can no longer observe the memory saving after each step. Specifically, I'm using the torch and torch_xla nightly 20220430 wheels on a v3-8 TPU VM and running the following code:
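(The exact snippet isn't reproduced above; the block below is a minimal stand-in sketch of that kind of test. It assumes the new wrapper is exposed as `torch_xla.utils.checkpoint.checkpoint`, and the toy model, batch size, and memory-reporting loop are illustrative choices of mine.)

```python
# Rough stand-in for the test described above; the model and sizes are made up.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.utils.checkpoint import checkpoint  # assumed module path

USE_GRADIENT_CHECKPOINT = True  # flip to False to compare the two cases

device = xm.xla_device()
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def forward(x):
    for layer in model:
        if USE_GRADIENT_CHECKPOINT:
            # Recompute this layer's activations in the backward pass instead
            # of storing them; per this PR, the XLA wrapper also inserts an
            # optimization barrier so the compiler cannot fold the
            # recomputation away.
            x = checkpoint(layer, x)
        else:
            x = layer(x)
    return x

for step in range(5):
    optimizer.zero_grad()
    # requires_grad on the input keeps gradients flowing through the first
    # checkpointed segment (a known pitfall of reentrant checkpointing).
    data = torch.randn(128, 4096, device=device, requires_grad=True)
    loss = forward(data).sum()
    loss.backward()
    optimizer.step()
    xm.mark_step()         # cut the lazy graph and launch execution
    xm.wait_device_ops()   # make sure the step finished before reading stats
    # Under XRT, xm.get_memory_info reports free/total device memory in KB.
    print(f"step {step}: {xm.get_memory_info(device)}")
```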
For the case WITH gradient checkpointing, the free TPU memory after each iteration drops from 11514304 KB to 11286448 KB; for the case WITHOUT gradient checkpointing, it stays at 11514304 KB. So gradient checkpointing increases the execution time as expected, but it no longer shows a reduction in memory usage; it even ends each iteration with less free memory. This is different from our observation in #3486 (comment), where gradient checkpointing saved memory and raised the free TPU memory after each iteration from 11514304 KB to 12901312 KB. I think this is perhaps related to the updated TensorFlow version or other issues on the XRT side. (I'm also verifying on my end whether this seemingly increased TPU memory consumption affects my real gradient checkpointing use cases.)
Update: Although the TPU's free memory is smaller after each iteration under the torch and torch_xla nightly 20220430 wheels, this gradient checkpointing API seems to work well and allows me to fit larger batch sizes. I suspect that the seemingly higher TPU memory usage at the end of each iteration does not reflect the peak memory usage during execution. I think this gradient checkpointing API still successfully reduces peak memory usage.
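(One way to probe peak memory, rather than end-of-step free memory, along the lines above is to sweep the batch size with and without checkpointing and see which configurations still fit. The sketch below is only illustrative: the toy model, the batch sizes, and the assumption that an out-of-memory failure surfaces as a RuntimeError at execution time are mine, not something reported in this thread.)

```python
# Illustrative sketch: which batch sizes fit with vs. without gradient
# checkpointing, as a proxy for peak memory usage. Model and sizes are made up.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.utils.checkpoint import checkpoint  # assumed module path

device = xm.xla_device()

def fits(batch_size, use_ckpt):
    model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).to(device)
    x = torch.randn(batch_size, 4096, device=device, requires_grad=True)
    try:
        out = x
        for layer in model:
            out = checkpoint(layer, out) if use_ckpt else layer(out)
        out.sum().backward()
        xm.mark_step()        # force compilation and execution now
        xm.wait_device_ops()  # surface any allocation failure here
        return True
    except RuntimeError:      # assumption: device OOM raises RuntimeError
        return False

for use_ckpt in (False, True):
    for bs in (1024, 4096, 16384, 65536):
        print(f"checkpointing={use_ckpt} batch_size={bs} fits={fits(bs, use_ckpt)}")
```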
FYI @ronghanghu