Add Checkpoint api to use optimization barrier #3524

Merged: 3 commits, Apr 26, 2022

Conversation

JackCaoG (Collaborator)

@ronghanghu (Collaborator)

This is great, thanks for adding it!

@JackCaoG (Collaborator, Author)

I think this PR is ready to land. I also want to move checkpoint_sequential to the xla module and copy some of the upstream checkpoint tests to guarantee the correctness of the checkpointing. I also want to add a test to verify that optimization_barrier is present in the HLO.
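
For illustration, here is a minimal sketch of what such a test could look like. It is not the actual test from this PR; it assumes the checkpoint API added here and uses the torch_xla._XLAC._get_xla_tensors_hlo debug helper to dump the pending HLO, where the barrier is expected to appear as an opt-barrier instruction.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.utils.checkpoint import checkpoint


def test_optimization_barrier_in_hlo():
    device = xm.xla_device()
    layer = torch.nn.Linear(8, 8).to(device)
    x = torch.randn(4, 8, device=device, requires_grad=True)

    # Forward + backward through the checkpointed layer without calling
    # mark_step(), so the backward graph (including the barrier) is still
    # pending and can be dumped as HLO text.
    out = checkpoint(layer, x)
    out.sum().backward()
    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([layer.weight.grad])

    # Assumption: the barrier is lowered to an "opt-barrier" HLO instruction.
    assert "opt-barrier" in hlo_text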

@ronghanghu We reverted the TF change since it broke the upstream CI build. I think we have one wheel that is built with the new TF to resolve the speed regression and the weird shape mismatch. Even if I merge this PR now, I don't recommend updating the wheel until we reland the TF change.

JackCaoG requested a review from miladm on April 26, 2022 at 02:33
@JackCaoG (Collaborator, Author)

I manually verified that this works on TPU and saw the expected memory savings. I will merge this PR for now. @miladm feel free to take a pass; I can address review comments in a follow-up PR. I also need to write better documentation around this part of the code.
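
As a starting point for that documentation, here is a minimal usage sketch. It assumes the new API is used as a drop-in replacement for torch.utils.checkpoint.checkpoint on XLA devices, the same way it is called in the test script further down in this thread.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.utils.checkpoint import checkpoint  # instead of torch.utils.checkpoint

device = xm.xla_device()
block = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
).to(device)
x = torch.randn(16, 1024, device=device, requires_grad=True)

# Activations inside `block` are discarded in the forward pass and recomputed
# in the backward pass; this PR wraps that recomputation in an optimization
# barrier so the XLA compiler does not optimize the rematerialization away.
y = checkpoint(block, x)
loss = y.sum()
loss.backward()
xm.mark_step()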

@ronghanghu (Collaborator)

Thanks for merging this PR!

> @ronghanghu We reverted the TF change since it broke the upstream CI build. I think we have one wheel that is built with the new TF to resolve the speed regression and the weird shape mismatch. Even if I merge this PR now, I don't recommend updating the wheel until we reland the TF change.

I see. I'd be happy to try out the new wheel after the TF version PR #3523 is relanded.

@ronghanghu (Collaborator)

ronghanghu commented Apr 30, 2022

I tried testing this API with the torch and torch_xla nightly 20220430 builds, and somehow I can no longer observe the memory savings after each step via xm.get_memory_info(device). I'm using libtpu 20220413 for this verification.

Specifically, I'm using the following environment on a v3-8 TPU VM:

# torch, torchvision and torch_xla 20220430
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220430-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220430-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220430-cp38-cp38-linux_x86_64.whl

# libtpu 20220413
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl

and running the following code (~/workspace/grad_ckpt_test/test_grad_checkpoint_20220430.py):

import argparse
import torch
from torch_xla.utils.checkpoint import checkpoint
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def run(grad_checkpoint):
    device = xm.xla_device()
    model = torch.nn.ModuleList(
        [
            torch.nn.Sequential(
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
            )
            for _ in range(64)
        ]
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

    for step in range(200):
        dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
        optimizer.zero_grad()
        x = dummy_data
        for n_l, layer in enumerate(model):
            # checkpoint every block except the first one
            if n_l > 0 and grad_checkpoint:
                x = checkpoint(layer, x)
            else:
                x = layer(x)
        dummy_loss = x.sum()
        dummy_loss.backward()
        optimizer.step()
        xm.mark_step()  # cut and execute the lazily built graph for this step
        print(f"step {step}, free memory = {xm.get_memory_info(device)['kb_free']}")

    print(met.metrics_report())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--grad_checkpoint", type=int, required=True)
    args = parser.parse_args()
    run(args.grad_checkpoint)

For the case WITH gradient checkpointing (python3 ~/workspace/grad_ckpt_test/test_grad_checkpoint_20220430.py --grad_checkpoint 1), I get

step 199, free memory = 11286448                                                                                                                                                                                                                   
...
Metric: ExecuteTime     
  TotalSamples: 199   
  Accumulator: 54s592ms080.230us                                                                                                                                                                                                                   
  ValueRate: 814ms385.839us / second
  Rate: 3.02401 / second
  Percentiles: 1%=267ms689.554us; 5%=267ms872.899us; 10%=267ms932.555us; 20%=267ms058.380us; 50%=267ms301.069us; 80%=268ms257.735us; 90%=271ms969.259us; 95%=282ms257.366us; 99%=313ms871.710us

For the case WITHOUT gradient checkpointing (python3 ~/workspace/grad_ckpt_test/test_grad_checkpoint_20220430.py --grad_checkpoint 0), I get

step 199, free memory = 11514304                                                                                                                                                                                                                   
...
Metric: ExecuteTime     
  TotalSamples: 199     
  Accumulator: 41s320ms632.742us                                                                                                                                                                                                                   
  ValueRate: 787ms676.172us / second
  Rate: 3.78872 / second
  Percentiles: 1%=204ms234.382us; 5%=204ms322.918us; 10%=204ms442.026us; 20%=205ms596.435us; 50%=205ms979.589us; 80%=206ms134.639us; 90%=211ms601.292us; 95%=220ms159.325us; 99%=258ms019.974us

So gradient checkpointing increases the execution time as expected, but it no longer shows a reduction in memory usage. It even appears to use more memory: the free TPU memory after each iteration drops from 11514304 KB to 11286448 KB.

This is different from our observation in #3486 (comment), where gradient checkpointing saved memory and raised the free TPU memory after each iteration from 11514304 KB to 12901312 KB. (Comparing the old and new builds, the TPU memory usage stayed the same for the case without gradient checkpointing, but increased for the case with gradient checkpointing.)

I think this is perhaps related to the updated TensorFlow versions or other issues on the XRT side.

(I'm also verifying on my end whether this seemingly increased TPU memory consumption affects my real gradient checkpointing use cases.)

@ronghanghu (Collaborator)

ronghanghu commented May 1, 2022

Update: Although the TPU's free memory is smaller after each iteration under the torch and torch_xla nightly 20220430 wheels, this gradient checkpointing API seems to work well and allows me to fit larger batch sizes.

I suspect that the seemingly higher TPU memory usage at the end of each iteration does not reflect the peak memory usage during execution. I think this gradient checkpointing API still successfully reduces peak memory usage.
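
As a rough, hypothetical way to quantify that (not something measured in this thread): since kb_free after mark_step() may not capture the in-flight peak, one could instead search for the largest batch size that completes a training step with and without gradient checkpointing, for example with a small helper like the one below. The step_fn here stands for a function that builds the model, runs one forward/backward pass at the given batch size, and calls xm.mark_step() so the graph actually executes.

def largest_fitting_batch_size(step_fn, candidates=(64, 96, 128, 160, 192)):
    """Return the largest candidate batch size whose training step executes
    without an allocation failure (assumed to surface as a RuntimeError)."""
    best = None
    for batch_size in sorted(candidates):
        try:
            step_fn(batch_size)
            best = batch_size
        except RuntimeError:
            break
    return best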
