Major speed regression in the nightly libtpu (20220308) compared to the 1.10 environment #3441


Closed
ronghanghu opened this issue Mar 23, 2022 · 33 comments

ronghanghu (Collaborator) commented Mar 23, 2022

🐛 Bug

There is a notable speed regression in a few models with the nightly 20220308 PyTorch XLA builds. For example, in MAE ViT-L (code attached below) the training time increased from 0.46 s/iter to 0.61 s/iter on a v3-128 pod after switching from tpu-vm-pt-1.10 to the 20220308 builds of torch, torchvision, torch_xla, and libtpu.

The speed issue is particularly related to libtpu: The speed regression above can be reproduced by installing libtpu_nightly-0.1.dev20220308 on the tpu-vm-pt-1.10 runtime environment while keeping torch, torchvision, and torch_xla as-is from tpu-vm-pt-1.10. (As long as libtpu_nightly-0.1.dev20220308 is used, the speed issue happens regardless of whether torch, torchvision, and torch_xla are the tpu-vm-pt-1.10 versions or the nightly 20220308 versions.)

Increased ExecuteTime metric in libtpu nightly 20220308: When running the very same codebase against the two libtpu versions (nightly 20220308 vs. tpu-vm-pt-1.10), the ExecuteTime metric increases dramatically with the nightly build. This speed issue happens on both a v3-128 TPU pod and a v3-8.
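For reference, the ExecuteTime numbers below come from PyTorch/XLA's metrics report; a minimal sketch of printing that report with the standard debug API (not the exact code in the attached script) is:

import torch_xla.debug.metrics as met

def print_xla_metrics():
    # Full report, including metrics such as ExecuteTime and CompileTime;
    # call this after some training steps have executed on the XLA device.
    print(met.metrics_report())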

Code and the dumped HLO graphs are attached below to reproduce this speed drop.

On v3-128 pod

The v3-128 pod is the practical use case and suffers the most from this issue. Below are the metrics (running with fake data).

  • libtpu (tpu-vm-pt-1.10)
2022-03-23 16:22:20 10.164.1.16 [0] Metric: ExecuteTime
2022-03-23 16:22:20 10.164.1.16 [0]   TotalSamples: 6300
2022-03-23 16:22:20 10.164.1.16 [0]   Accumulator: 45m58s967ms609.476us
2022-03-23 16:22:20 10.164.1.16 [0]   ValueRate: 921ms981.695us / second
2022-03-23 16:22:20 10.164.1.16 [0]   Rate: 2.43071 / second
2022-03-23 16:22:20 10.164.1.16 [0]   Percentiles: 1%=032ms627.252us; 5%=342ms396.858us; 10%=346ms541.697us; 20%=347ms044.893us; 50%=350ms108.598us; 80%=404ms203.194us; 90%=448ms869.785us; 95%=489ms928.152us; 99%=580ms060.833us
  • libtpu (nightly 20220308)
2022-03-23 17:12:38 10.164.0.252 [0] Metric: ExecuteTime
2022-03-23 17:12:38 10.164.0.252 [0]   TotalSamples: 6300
2022-03-23 17:12:38 10.164.0.252 [0]   Accumulator: 55m10s504ms688.128us
2022-03-23 17:12:38 10.164.0.252 [0]   ValueRate: 937ms688.881us / second
2022-03-23 17:12:38 10.164.0.252 [0]   Rate: 1.89054 / second
2022-03-23 17:12:38 10.164.0.252 [0]   Percentiles: 1%=030ms738.875us; 5%=460ms803.746us; 10%=463ms835.271us; 20%=465ms076.200us; 50%=469ms870.575us; 80%=522ms858.229us; 90%=569ms181.951us; 95%=633ms203.662us; 99%=706ms484.193us

The 50th percentile increased from 350ms108.598us to 469ms870.575us (a major speed drop), and the rate dropped from 2.43071 / second to 1.89054 / second.

On v3-8

The v3-8 also suffers from a similar drop. Below are the metrics (running with fake data) under the same per-TPU-core batch size as on v3-128.

  • libtpu (tpu-vm-pt-1.10)
Metric: ExecuteTime
  TotalSamples: 20028
  Accumulator: 02h05m49s508ms836.695us
  ValueRate: 898ms831.846us / second
  Rate: 2.57446 / second
  Percentiles: 1%=331ms520.482us; 5%=333ms254.737us; 10%=336ms635.889us; 20%=339ms375.811us; 50%=348ms614.947us; 80%=359ms644.435us; 90%=364ms916.825us; 95%=368ms475.100us; 99%=388ms471.869us
  • libtpu (nightly 20220308)
Metric: ExecuteTime
  TotalSamples: 20028
  Accumulator: 02h11m30s275ms580.059us
  ValueRate: 920ms289.906us / second
  Rate: 2.36889 / second
  Percentiles: 1%=368ms370.199us; 5%=373ms378.164us; 10%=375ms382.477us; 20%=379ms886.261us; 50%=388ms388.937us; 80%=397ms830.420us; 90%=402ms653.380us; 95%=407ms905.002us; 99%=437ms029.747us

Here the 50th percentile increased from 348ms614.947us to 388ms388.937us (a smaller speed drop than on v3-128, but still quite notable).

The dumped HLO graph for the v3-8 case (using XLA_SAVE_TENSORS_FILE=hlo_$(date +%Y-%m-%d_%H-%M-%S).txt and XLA_SAVE_TENSORS_FMT=hlo; run only for 5 iterations to avoid saving too many outputs), as well as the stdout and stderr outputs containing the printed metrics above, are uploaded to Google Drive.
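A minimal sketch of enabling the same HLO dump from inside a script (the file path below is a placeholder; setting the variables in the launch shell, as described above, works equally well):

import os

# Set the debug variables before torch_xla is imported so the dump settings
# apply to the whole run; the path is just a placeholder.
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/hlo_dump.txt"
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"

import torch_xla.core.xla_model as xm  # imported after the env vars are set
# ... build the model and run a few training iterations as usual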

To Reproduce

  1. Allocate a v3-128 TPU VM pod (or a v3-8 TPU VM) from the tpu-vm-pt-1.10 environment
  2. Install the timm library via sudo pip3 install timm==0.4.12 on all VM nodes (e.g. through gcloud alpha compute tpus tpu-vm ssh --worker all)
  3. Download the codebase in https://drive.google.com/file/d/1-iXaPkiHr3c4nTqKx8c6zK1cx5ba_T0p/view?usp=sharing to all TPU VM nodes (e.g. put it under /checkpoint/ronghanghu/workspace/mae_tpu)
  4. Install libtpu:
  • To reproduce the case of tpu-vm-pt-1.10, no change is needed
  • To reproduce the case of libtpu nightly 20220308, install it via sudo pip3 uninstall -y libtpu_nightly && sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220308-py3-none-any.whl on all VM nodes

On v3-128 pod

To run on a v3-128 pod (print metrics and save stdout and stderr; running with fake data):

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-128_libtpu_pt110
MODEL=mae_vit_large_patch16
EPOCH=20
BATCH_SIZE_PER_TPU=32

TPU_NAME=debug-v3-128  # change to your TPU name

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
cd ${HOME} && python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod-server \
  --env XRT_MESH_CONNECT_WAIT=1200 --env PYTHONUNBUFFERED=1 -- \
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

The code will print the XLA metrics such as ExecuteTime after each training epoch.

On v3-8

To run on a v3-8 (print metrics and save stdout and stderr; running with fake data):

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-8_libtpu_pt110
MODEL=mae_vit_large_patch16
EPOCH=4
BATCH_SIZE_PER_TPU=32

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

The code will print the XLA metrics such as ExecuteTime after each training epoch.

Expected behavior

The speed regression should not happen when switching to the nightly 20220308 version of libtpu.

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-128 (or v3-8) TPU VM starting from tpu-vm-pt-1.10 (with the nightly 20220308 version of libtpu installed)
  • torch_xla version: 1.10

Additional context

The larger speed drop on a v3-128 pod suggests that there could be significant overhead when running the same model on a pod vs. on a v3-8 under the same per-TPU batch size. This is somewhat unexpected: the model itself does not have any cross-TPU communication such as all_gather in it, and the only op involving cross-TPU communication is the gradient reduction before applying optimizer updates. So I believe it would be great to investigate and resolve this overhead, as it severely hurts many practical use cases.
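For reference, the gradient reduction mentioned above is the all-reduce performed by xm.optimizer_step; a minimal sketch of that per-step pattern (using the standard torch_xla multiprocessing APIs, not the exact MAE training loop) is:

import torch_xla.core.xla_model as xm

def train_step(model, images, optimizer):
    optimizer.zero_grad()
    loss = model(images)          # forward/backward only build the lazy XLA graph
    loss.backward()
    xm.optimizer_step(optimizer)  # all-reduces gradients across cores, then steps
    xm.mark_step()                # cut the graph and trigger device execution
    return loss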

cc @JackCaoG

JackCaoG (Collaborator) commented

@yeounoh Can you take a look?

yeounoh (Contributor) commented Mar 25, 2022

Hi @ronghanghu, thanks for reporting the issue. I was able to reproduce it and can confirm the issue with the latest libtpu-nightly as well. We will continue debugging and keep you posted.

ronghanghu (Collaborator, Author) commented

@yeounoh That's awesome, thanks for your help!

yeounoh (Contributor) commented Apr 14, 2022

Hi @ronghanghu , we were able to address an issue with one of the collective ops, and it seems to yield slightly better performance (~10us improvement in ExecuteTime at 50%). FWIW, if you'd like to try the latest torch_xla with a newer libtpu-nightly, feel free to try sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl .

We will continue investigating and keep you posted.

ronghanghu (Collaborator, Author) commented

This is great to hear, thanks for the update, @yeounoh!

yeounoh (Contributor) commented Apr 19, 2022

Hi @ronghanghu, I realized that you were having an issue with the libtpu-nightly. Let me also take a look...

ronghanghu (Collaborator, Author) commented Apr 19, 2022

That's great, thank you!

yeounoh (Contributor) commented Apr 21, 2022

Hi @ronghanghu , we were able to address the regression -- this PR will land sometime today; you can then pull the change, run git submodule update --recursive, and use it with the latest libtpu-nightly.

Here is the result I got using libtpu-nightly==0.1.dev20220413:

Metric: ExecuteTime
  TotalSamples: 5007
  Accumulator: 29m17s386ms909.913us
  ValueRate: 902ms614.043us / second
  Rate: 2.57767 / second
  Percentiles: 1%=339ms099.956us; 5%=340ms390.282us; 10%=341ms488.621us; 20%=343ms063.581us; 50%=347ms510.279us; 80%=355ms273.082us; 90%=366ms405.235us; 95%=373ms337.098us; 99%=389ms789.397us

Would you be able to help us verify?

ronghanghu (Collaborator, Author) commented Apr 21, 2022

Great, I'll check it out. Thank you!

@yeounoh Just to double-check, to try out this new version, should I use the 20220422 PyTorch XLA build from tomorrow (that includes #3507) along with libtpu-nightly==0.1.dev20220422, or should I use 20220422 PyTorch XLA but still keep libtpu-nightly==0.1.dev20220413 as in your comments above?

JackCaoG (Collaborator) commented

The TF pin is at 04/07, so I expect both dev20220422 and dev20220413 to work. If you run into issues with dev20220422, let us know and use dev20220413 while we debug the regression... again :D

yeounoh (Contributor) commented Apr 29, 2022

Hi @ronghanghu, we've resolved the build issue regarding the TensorFlow update. Once we land #3535 and #3541, you should be able to check out the latest master and use it with libtpu-nightly==0.1.dev20220413. Let me know if you have any questions.

ronghanghu (Collaborator, Author) commented

@yeounoh Thanks! I'll check it out when these PRs are merged

ronghanghu (Collaborator, Author) commented Apr 30, 2022

Update: following up on this -- I tried out the nightly 20220430 wheels for torch_xla, torch, and torchvision that contain #3535, #3541, and tested the speed again on a v3-128 pod. It seems that the speed regression issue persists in this nightly 20220430 version.

Actually, the speed regression in the case above seems worse in the nightly 20220430 torch_xla wheel (used along with libtpu-nightly==0.1.dev20220413) compared to the previous nightly 20220308 wheels (when this issue was originally raised), as detailed below.


I'm running it as follows using a v3-128 TPU VM ptxla-dev-128-1 and starting with the tpu-vm-pt-1.10 runtime environment:

  1. Installing timm 0.4.12, the nightly 20220430 wheels, and libtpu 20220413.
TPU_NAME=ptxla-dev-128-1
ZONE=europe-west4-a

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} \
  --worker all \
  --command "sudo pip3 install timm==0.4.12 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl"
  2. Downloading the code above to /checkpoint/ronghanghu/workspace/mae_tpu_speed_test_20220430 on all VM nodes.

  3. Running it on the v3-128 pod as follows.

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu_speed_test_20220430  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-128_libtpu_20220413
MODEL=mae_vit_large_patch16
EPOCH=20
BATCH_SIZE_PER_TPU=32

TPU_NAME=ptxla-dev-128-1

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
cd ${HOME} && python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod-server \
  --env XRT_MESH_CONNECT_WAIT=1200 --env PYTHONUNBUFFERED=1 -- \
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

I'm getting the following metrics at the end (after 20 epochs)

2022-04-30 17:47:30 10.164.1.109 [0] Metric: ExecuteTime                                                                                                                                                                                           
2022-04-30 17:47:30 10.164.1.109 [0]   TotalSamples: 6300                                                                                                                                                                                          
2022-04-30 17:47:30 10.164.1.109 [0]   Accumulator: 01h16m13s119ms677.065us                                                                                                                                                                        
2022-04-30 17:47:30 10.164.1.109 [0]   ValueRate: 965ms203.585us / second                                                                                                                                                                          
2022-04-30 17:47:30 10.164.1.109 [0]   Rate: 1.42488 / second                                                                                                                                                                                      
2022-04-30 17:47:30 10.164.1.109 [0]   Percentiles: 1%=094ms309.965us; 5%=623ms090.260us; 10%=624ms482.893us; 20%=626ms320.719us; 50%=673ms548.135us; 80%=735ms576.942us; 90%=753ms104.198us; 95%=824ms024.495us; 99%=866ms028.639us

So it shows that the 50th percentile increased from 350ms108.598us (in tpu-vm-pt-1.10) to 673ms548.135us. This seems even worse than the 469ms870.575us metric in the nightly 20220308 wheels. (Also note that data loading is not a bottleneck here, as we are using fake data for this speed test.)

@yeounoh Could you take a further look at this speed issue? Thanks a lot!

@JackCaoG We also observe that the nightly 20220430 wheels might break the gradient checkpointing example (#3524 (comment)). I wonder whether it could be related to this speed regression issue here? Gradient checkpointing is still working in the nightly 20220430 wheels, despite less free TPU memory at the end of each iteration, as noted in #3524 (comment).

yeounoh (Contributor) commented May 2, 2022

Hi @ronghanghu, sorry to hear that; I will take a look. Just in case you've already tried it: does this appear on v3-8 as well, like the previous regression?

ronghanghu (Collaborator, Author) commented

@yeounoh Thanks again! I haven't retried on v3-8 yet (our real use case was on TPU pod). I can probably take another look at it this evening.

ronghanghu (Collaborator, Author) commented May 3, 2022

Update: the v3-8 speed also got much slower with the nightly 20220430 torch, torch_xla, and torchvision wheels than with the previous tpu-vm-pt-1.10 runtime.

Metric: ExecuteTime                                                                                                                                                                                                                                
  TotalSamples: 55077                                                                                                                                                                                                                              
  Accumulator: 10h02m14s272ms366.865us                                                                                                                                                                                                             
  ValueRate: 967ms941.605us / second                                                                                                                                                                                                               
  Rate: 1.5071 / second                                                                                                                                                                                                                            
  Percentiles: 1%=621ms978.636us; 5%=624ms824.197us; 10%=625ms381.422us; 20%=628ms290.496us; 50%=644ms827.592us; 80%=650ms305.695us; 90%=656ms844.635us; 95%=663ms354.183us; 99%=688ms185.173us

The 50th percentile of ExecuteTime on v3-8 is now as long as 644ms827.592us. This is a lot worse than the corresponding 348ms614.947us measured previously in tpu-vm-pt-1.10, almost doubling the execution time.


It's measured as follows (running for 10 epochs):

# Installing timm 0.4.12, the nightly 20220430 wheels, and libtpu 20220413
sudo pip3 install timm==0.4.12 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu_speed_test_20220430  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-8_libtpu_20220413
MODEL=mae_vit_large_patch16
EPOCH=10
BATCH_SIZE_PER_TPU=32

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

ronghanghu (Collaborator, Author) commented

A minor update: given that unsetting LD_PRELOAD could cause a memory leak (#3545 (comment)), I tried running the v3-8 speed test (under the nightly 20220430 wheels and libtpu 20220413) with either export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 or unset LD_PRELOAD.

In both cases, the 50th percentile of ExecuteTime on v3-8 is 630+ ms, much slower than in the tpu-vm-pt-1.10 environment (around 348ms614.947us), as detailed below. So tcmalloc vs. standard C++ malloc is not the cause of this performance regression.


With tcmalloc

# Installing timm 0.4.12, the nightly 20220430 wheels, and libtpu 20220413
sudo pip3 install timm==0.4.12 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl

# With `tcmalloc`
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu_speed_test_20220430  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-8_libtpu_20220413_with_LD_PRELOAD
MODEL=mae_vit_large_patch16
EPOCH=1
BATCH_SIZE_PER_TPU=32

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

prints

Metric: ExecuteTime
  TotalSamples: 5007
  Accumulator: 53m21s514ms053.723us
  ValueRate: 964ms783.489us / second
  Rate: 1.51582 / second
  Percentiles: 1%=617ms472.659us; 5%=620ms077.718us; 10%=621ms122.524us; 20%=624ms037.561us; 50%=636ms416.594us; 80%=645ms688.789us; 90%=653ms629.758us; 95%=662ms907.008us; 99%=675ms135.820us

With standard C++ malloc

# Installing timm 0.4.12, the nightly 20220430 wheels, and libtpu 20220413
sudo pip3 install timm==0.4.12 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220430-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl

# With standard C++ `malloc`
unset LD_PRELOAD

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu_speed_test_20220430  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-8_libtpu_20220413_without_LD_PRELOAD
MODEL=mae_vit_large_patch16
EPOCH=1
BATCH_SIZE_PER_TPU=32

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

prints

Metric: ExecuteTime
  TotalSamples: 5007
  Accumulator: 54m30s345ms238.447us
  ValueRate: 957ms309.366us / second
  Rate: 1.50187 / second
  Percentiles: 1%=618ms006.271us; 5%=621ms547.657us; 10%=623ms666.277us; 20%=626ms417.864us; 50%=639ms016.085us; 80%=646ms705.615us; 90%=650ms150.774us; 95%=657ms915.178us; 99%=682ms412.205us

yeounoh (Contributor) commented May 10, 2022

Hi @ronghanghu , wanted to give a quick update.

I was able to reproduce the (worsened) regression using the same torch_xla-nightly+20220430 and libtpu-nightly-0.1.dev20220413:

Metric: ExecuteTime
  TotalSamples: 5007
  Accumulator: 48m09s643ms910.710us
  ValueRate: 951ms889.405us / second
  Rate: 1.6492 / second
  Percentiles: 1%=571ms955.587us; 5%=572ms423.067us; 10%=573ms264.672us; 20%=574ms491.058us; 50%=577ms035.262us; 80%=580ms859.434us; 90%=582ms093.667us; 95%=586ms282.053us; 99%=594ms950.517us

Now, I took the torch_xla 1.11 release branch, applied a few manual patches, and ran again with the same libtpu-nightly-0.1.dev20220413:

Metric: ExecuteTime
  TotalSamples: 5007
  Accumulator: 30m31s803ms248.815us
  ValueRate: 909ms161.580us / second
  Rate: 2.58249 / second
  Percentiles: 1%=339ms896.135us; 5%=341ms681.130us; 10%=342ms062.884us; 20%=344ms755.787us; 50%=348ms338.145us; 80%=360ms771.500us; 90%=371ms760.170us; 95%=377ms932.335us; 99%=403ms763.110us

The same libtpu (with the all_reduce path) runs just fine -- this is the performance we observed with the older libtpu and torch_xla==1.10 before the regression.

I think some of the recent changes (after the 1.11 release) are causing this more severe regression. We will continue investigating and fix the nightly regression (you don't have to fall back to 1.11).

ronghanghu (Collaborator, Author) commented

@yeounoh I see, thanks for the update!

yeounoh (Contributor) commented May 13, 2022

Hi @ronghanghu, a quick update -- we were able to track down a bug in our torch.gather implementation where sparse_grad=True is forced. This shouldn't always be the case. I will tag you in a relevant PR that fixes it; meanwhile,

this is the new nightly number after the patch:

Metric: ExecuteTime
  TotalSamples: 5007
  Accumulator: 31m07s735ms512.868us
  ValueRate: 330ms718.486us / second
  Rate: 0.885926 / second
  Percentiles: 1%=364ms303.843us; 5%=366ms200.505us; 10%=367ms443.119us; 20%=369ms558.663us; 50%=371ms186.603us; 80%=375ms360.225us; 90%=380ms915.686us; 95%=386ms566.346us; 99%=396ms800.193us

There is still a ~20ms gap from the baseline (torch_xla==1.11 with libtpu-nightly-0.1.dev20220413); we will continue investigating.
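For background on the sparse_grad flag mentioned above, here is a small plain-PyTorch illustration (not the torch_xla lowering itself) showing that the flag switches the backward of torch.gather from a dense to a sparse gradient:

import torch

src = torch.randn(4, 8, requires_grad=True)
idx = torch.randint(0, 8, (4, 3))

# Dense backward: the gradient of `src` is a regular dense tensor.
out_dense = torch.gather(src, 1, idx, sparse_grad=False)
out_dense.sum().backward()
print(src.grad.layout)  # torch.strided

src.grad = None
# Sparse backward: the gradient of `src` is a sparse COO tensor instead.
out_sparse = torch.gather(src, 1, idx, sparse_grad=True)
out_sparse.sum().backward()
print(src.grad.layout)  # torch.sparse_coo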

ronghanghu (Collaborator, Author) commented

Thanks, @yeounoh, this is good to know and the new numbers look much better. Looking forward to the fix!

ronghanghu (Collaborator, Author) commented

Following up on this: I tried out the nightly 20220518 wheels and can confirm that the speed got much better, as @yeounoh reported above, after #3566 was merged.

yeounoh (Contributor) commented May 18, 2022

That's great, @ronghanghu, thank you for confirming. I was able to use the 1.11 release branch, apply some patches, and run with libtpu-nightly-0.1.dev20220413 to match the 1.10 performance. However, I am still struggling to close the remaining gap between 1.10/1.11 (~350ms @50p, v3-8) and the current nightly (~370ms @50p, v3-8). We will continue investigating and let you know when we fix it.

ronghanghu (Collaborator, Author) commented

Sounds great, thanks @yeounoh!

ronghanghu (Collaborator, Author) commented Aug 17, 2022

Hi @yeounoh, the example above (running MAE on TPUs) seems to have stopped working on PyTorch/XLA 1.12 on TPU pods. It hangs on the first iteration and crashes after a while. Do you know of any recent change (perhaps in libtpu) that could break things here?

When I run it on a v3-128 TPU VM (e.g. "debug-v3-128" below) created from the tpu-vm-pt-1.12 environment, it hangs on the first iteration and crashes after a while.

On the other hand, it still works on a v3-8 under the tpu-vm-pt-1.12 environment. (Previously it worked on both v3-8 and v3-128 on tpu-vm-pt-1.10 and earlier nightly environments, as mentioned in the thread above.)


To reproduce it:

  1. Allocate a v3-128 TPU VM pod from the tpu-vm-pt-1.12 environment.
  2. Install the timm library via sudo pip3 install timm==0.4.12 and the CPU version of TensorFlow (to resolve #3786, "torch.utils.tensorboard (or tensorflow in general) cannot be imported/used under PT/XLA 1.12 (tpu-vm-pt-1.12)") on all VM nodes (e.g. through gcloud alpha compute tpus tpu-vm ssh --worker all):
TPU_NAME=debug-v3-128  # change to your TPU name
ZONE=europe-west4-a  # change to your TPU zone
PROJECT_ID=fair-infra3f4ebfe6  # change to your GCP project id

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID} --worker all \
  --command "
sudo pip3 install timm==0.4.12
# to resolve https://github.com/pytorch/xla/issues/3786
sudo pip3 uninstall -y tensorflow && sudo pip3 install tensorflow-cpu==2.9.1  # 2.9.1 is the latest version of "tensorflow-cpu"
"
  3. Download the codebase in https://drive.google.com/file/d/1-iXaPkiHr3c4nTqKx8c6zK1cx5ba_T0p/view?usp=sharing to all TPU VM nodes (e.g. put it under /checkpoint/ronghanghu/workspace/mae_tpu).
  4. Then, launch it on the v3-128 TPU VM ("debug-v3-128" below) as follows:
MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-128_libtpu_pt120
MODEL=mae_vit_large_patch16
EPOCH=20
BATCH_SIZE_PER_TPU=32

TPU_NAME=debug-v3-128  # change to your TPU name

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
cd ${HOME} && python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod-server \
  --env XRT_MESH_CONNECT_WAIT=1200 --env PYTHONUNBUFFERED=1 -- \
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

This hangs on the 1st iteration on a v3-128 pod, eventually printing an error of Received a GOAWAY with error code ENHANCE_YOUR_CALM and debug data equal to "too_many_pings".

...
2022-08-17 21:56:05 10.164.1.59 [0] [21:56:05.215938] Start training for 20 epochs
2022-08-17 21:56:05 10.164.1.59 [0] [21:56:05.218035] log_dir: /home/ronghanghu/vitl_mae_debug_fakedata_v3-128_libtpu_pt120
2022-08-17 22:21:58 10.164.0.233 [14] E0817 22:21:58.905253845   12953 chttp2_transport.cc:1098]   Received a GOAWAY with error code ENHANCE_YOUR_CALM and debug data equal to "too_many_pings"
2022-08-17 22:21:58 10.164.0.233 [14] 2022-08-17 22:21:58.905600: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Too many pings" and grpc_error_string = "{"created":"@1660774918.905323917","description":"Error received from peer ipv4:10.164.0.233:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Too many pings","grpc_status":14}", maybe retrying the RPC

On the other hand, it does work on a v3-8 TPU VM under the tpu-vm-pt-1.12 environment as follows:
sudo pip3 install timm==0.4.12
# to resolve https://github.com/pytorch/xla/issues/3786
sudo pip3 uninstall -y tensorflow && sudo pip3 install tensorflow-cpu==2.9.1  # 2.9.1 is the latest version of "tensorflow-cpu"

and

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu  # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-8_libtpu_pt120
MODEL=mae_vit_large_patch16
EPOCH=4
BATCH_SIZE_PER_TPU=32

sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
python3 ${MAE_PATH}/main_pretrain.py \
    --output_dir ${SAVE_DIR} \
    --log_dir ${SAVE_DIR} \
    --batch_size ${BATCH_SIZE_PER_TPU} \
    --model ${MODEL} \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs ${EPOCH} \
    --warmup_epochs 40 \
    --blr 0.0 --weight_decay 0.05 \
    --fake_data \
    --num_workers 8 \
    --use_xla \
    2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

which runs well and prints

...
[20:37:27.997808] Start training for 4 epochs
[20:37:28.001358] log_dir: /home/ronghanghu/vitl_mae_debug_fakedata_v3-8_libtpu_pt120
[20:37:32.826325] Epoch: [0]  [   0/5004]  eta: 6:38:12  lr: 0.000000  time: 4.7747  data: 4.3553
[20:45:33.766382] Epoch: [0]  [  20/5004]  eta: 1 day, 8:01:16  lr: 0.000000  loss: 0.7176 (0.7176)  time: 24.0470  data: 23.9321
[20:45:41.944633] Epoch: [0]  [  40/5004]  eta: 16:36:36  lr: 0.000000  loss: 0.7176 (0.7176)  time: 0.4089  data: 0.2924
[20:45:50.304374] Epoch: [0]  [  60/5004]  eta: 11:18:26  lr: 0.000000  loss: 0.7176 (0.7177)  time: 0.4180  data: 0.3025

JackCaoG (Collaborator) commented

@ronghanghu Can you share the server log under /tmp/xrt_server_log?

ronghanghu (Collaborator, Author) commented Aug 17, 2022

@JackCaoG I uploaded the XRT server log (on TPU VM worker 0 for the case of v3-128 pod) to https://gist.github.com/ronghanghu/36d0c29085d517e181e6edd460c6dfd8#file-server_20220817-215525-log. I guess this new issue is most likely because of some changes in libtpu.

(And this gist above also has the stdout and stderr output file from the training command on v3-128 above.)

JackCaoG (Collaborator) commented

2022-08-17 21:56:59.546433: I tensorflow/compiler/xrt/kernels/tpu_execute_op.cc:271] XRTExecuteOp::Compute
https://symbolize.stripped_domain/r/?trace=7ff716876cd7,7ff7168d308f&map= 
*** SIGTERM received by PID 13084 (TID 13084) on cpu 32 from PID 28473; stack trace: ***
PC: @     0x7ff716876cd7  (unknown)  __pthread_clockjoin_ex
    @     0x7ff636a8f2f3        992  (unknown)
    @     0x7ff7168d3090  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7ff716876cd7,7ff636a8f2f2,7ff7168d308f&map=5920735bb186a93a82e10840f91bc184:7ff62229c000-7ff636e12fb0 
E0817 22:28:02.291921   13084 coredump_hook.cc:320] RAW: Remote crash gathering disabled for SIGTERM.
E0817 22:28:02.835050   13084 process_state.cc:774] RAW: Raising signal 15 with default behavior

It seems like this execution hangs for half an hour and then just times out. Let's wait for #3899 to merge, then we can try a newer pt/xla wheel and libtpu. We are also trying to bump to the August libtpu again this week. @wonjoolee95 FYI.

If we are still hitting the same issue with the latest libtpu, I will file a bug and get someone from the XLA team to help.

ronghanghu (Collaborator, Author) commented

@JackCaoG Sounds good. The most recent PyTorch/XLA (and libtpu) would sometimes just hang on several models that could previously work under PT/XLA 1.10 and 1.11. I'll wait for #3899 and check out the nightly wheels afterward.

ronghanghu (Collaborator, Author) commented Aug 23, 2022

It seems like this execution hangs for half an hour and then just times out. Let's wait for #3899 to merge, then we can try a newer pt/xla wheel and libtpu. We are also trying to bump to the August libtpu again this week. @wonjoolee95 FYI.

@JackCaoG Following up on this: with libtpu_nightly-0.1.dev20220623 (based on #3899) and the nightly 20220822 versions of torch, torch_xla, and torchvision, the jobs can now be launched on v3-128 for the example above.

However, there is quite a serious speed regression (3X+ slower) for this example using PT/XLA nightly 20220822 + libtpu_nightly-0.1.dev20220623 in comparison with the PT/XLA 1.10 version, as detailed below.


Installing the new nightly environment as follows:

# to resolve https://github.com/pytorch/xla/issues/3786
sudo pip3 uninstall -y tensorflow && sudo pip3 install tensorflow-cpu==2.9.1  # 2.9.1 is the latest version of "tensorflow-cpu"

# torch, torchvision and torch_xla
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220822-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220822-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220822-cp38-cp38-linux_x86_64.whl
# libtpu
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220623-py3-none-any.whl

# dependencies
sudo pip3 install timm==0.4.12
sudo pip3 install numpy==1.23.0
"

and running the profiling as described in the original issue above, we get the results below.

On v3-128 pod

The v3-128 pod is the practical use case and suffers the most from this issue. Below are the metrics (running with fake data).

  • PT/XLA 1.10 + libtpu from tpu-vm-pt-1.10 runtime
2022-03-23 16:22:20 10.164.1.16 [0] Metric: ExecuteTime
2022-03-23 16:22:20 10.164.1.16 [0]   TotalSamples: 6300
2022-03-23 16:22:20 10.164.1.16 [0]   Accumulator: 45m58s967ms609.476us
2022-03-23 16:22:20 10.164.1.16 [0]   ValueRate: 921ms981.695us / second
2022-03-23 16:22:20 10.164.1.16 [0]   Rate: 2.43071 / second
2022-03-23 16:22:20 10.164.1.16 [0]   Percentiles: 1%=032ms627.252us; 5%=342ms396.858us; 10%=346ms541.697us; 20%=347ms044.893us; 50%=350ms108.598us; 80%=404ms203.194us; 90%=448ms869.785us; 95%=489ms928.152us; 99%=580ms060.833us
  • PT/XLA nightly 20220822 + libtpu_nightly-0.1.dev20220623
2022-08-22 23:50:19 10.164.2.29 [0] Metric: ExecuteTime
2022-08-22 23:50:19 10.164.2.29 [0]   TotalSamples: 6300
2022-08-22 23:50:19 10.164.2.29 [0]   Accumulator: 59m10s285ms312.600us
2022-08-22 23:50:19 10.164.2.29 [0]   ValueRate: 373ms959.425us / second
2022-08-22 23:50:19 10.164.2.29 [0]   Rate: 0.659334 / second
2022-08-22 23:50:19 10.164.2.29 [0]   Percentiles: 1%=073ms081.885us; 5%=500ms524.627us; 10%=511ms742.392us; 20%=523ms706.064us; 50%=552ms560.581us; 80%=592ms934.282us; 90%=637ms361.024us; 95%=735ms499.771us; 99%=897ms209.947us

Here the 50th percentile increased from 350ms108.598us to 552ms560.581us, and the rate dropped from 2.43071 / second to 0.659334 / second (3.7X slower).

On v3-8

The v3-8 also suffers from a similar drop. Below are the metrics (running with fake data) under the same per-TPU-core batch size as on v3-128.

  • PT/XLA 1.10 + libtpu from tpu-vm-pt-1.10 runtime
Metric: ExecuteTime
  TotalSamples: 20028
  Accumulator: 02h05m49s508ms836.695us
  ValueRate: 898ms831.846us / second
  Rate: 2.57446 / second
  Percentiles: 1%=331ms520.482us; 5%=333ms254.737us; 10%=336ms635.889us; 20%=339ms375.811us; 50%=348ms614.947us; 80%=359ms644.435us; 90%=364ms916.825us; 95%=368ms475.100us; 99%=388ms471.869us
  • PT/XLA nightly 20220822 + libtpu_nightly-0.1.dev20220623
Metric: ExecuteTime
  TotalSamples: 20028
  Accumulator: 06h14m53s267ms033.059us
  ValueRate: 932ms441.766us / second
  Rate: 0.822482 / second
  Percentiles: 1%=666ms934.011us; 5%=881ms291.103us; 10%=01s024ms694.430us; 20%=01s048ms862.406us; 50%=01s084ms259.627us; 80%=01s295ms712.216us; 90%=01s403ms335.924us; 95%=01s488ms172.983us; 99%=02s548ms072.847us

Here the 50th percentile increased from 348ms614.947us to 01s084ms259.627us, and the rate dropped from 2.57446 / second to 0.822482 / second (3.1X slower). It is also odd that the 50th percentile of ExecuteTime is much slower on v3-8 in this case (despite the same per-TPU batch size), although the final rate is still slightly higher on v3-8.

Such a 3X+ slowdown indicates something is wrong with the current system. We plan to include this example in our profiling benchmarks (which we will set up in both JAX and PT/XLA) and see if we can find a clue as to why there is such a speed drop from PT/XLA 1.10 to the current version.
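As a starting point for such benchmarks, a minimal sketch of capturing a device profile with the PyTorch/XLA profiler (the port, log directory, and duration below are placeholders) could look like:

import torch_xla.debug.profiler as xp

def start_profiler_in_training_process():
    # Keep the returned server object alive for the lifetime of training.
    return xp.start_server(9012)

def capture_trace_from_another_process():
    # Connects to the server above and writes a trace viewable in TensorBoard.
    xp.trace("localhost:9012", "/tmp/mae_profile", duration_ms=10000)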

JackCaoG (Collaborator) commented

Thanks @ronghanghu. Can you open a new GitHub issue? We are working on the libtpu update again and hopefully can get it out by the end of this week. We will see if that solves the issue you are seeing. If not, we will work with the XLA team to fix the regression. @wonjoolee95 FYI.

yeounoh (Contributor) commented Sep 2, 2022

That's great, @ronghanghu, thank you for confirming. I was able to use the 1.11 release branch, apply some patches, and run with libtpu-nightly-0.1.dev20220413 to match the 1.10 performance. However, I am still struggling to close the remaining gap between 1.10/1.11 (~350ms @50p, v3-8) and the current nightly (~370ms @50p, v3-8). We will continue investigating and let you know when we fix it.

Closing out the previous regression discussion (the ~6% speed regression in the May nightly): we've landed pytorch/pytorch#84503 in PyTorch to address the issue, and the patched May nightly was able to match the performance of 1.10. We should now be able to focus on the changes since May for the current regression investigation.

cc @ronghanghu if you are opening up a new issue, you can close this one.

ronghanghu (Collaborator, Author) commented

Thanks @yeounoh for the investigation and the fix!

And for the recent issue we saw above in #3441 (comment), we will wait until the new TensorFlow pin (#3922) lands, test again, and submit a new issue if we still see performance problems with this example.
