Major speed regression in the nightly libtpu (20220308) compared to the 1.10 environment #3441
Comments
@yeounoh Can you take a look?
Hi @ronghanghu, thanks for reporting the issue. I was able to reproduce and also confirm the issue with the latest libtpu-nightly.
@yeounoh That's awesome, thanks for your help!
Hi @ronghanghu, we were able to address an issue with one of the collective ops and it seems to yield slightly better performance (~10us improvement in ExecuteTime at 50%). FWIW, you can try the latest libtpu-nightly if you'd like. We will continue investigating and keep you posted.
This is great to hear, thanks for the update, @yeounoh!
Hi @ronghanghu, I realized that you were having an issue with the libtpu-nightly. Let me also take a look...
That's great, thank you!
Hi @ronghanghu, we were able to address the regression -- this PR will land sometime today; you can pull the change and rebuild. Here is the result I got using the patched build:
Would you be able to help us verify?
Great, I'll check it out. Thank you! @yeounoh Just to double-check, to try out this new version, should I use the nightly wheels or build from source?
The tf is pinned to 04/07, so I expect both to work.
Hi @ronghanghu, we've resolved the build issue regarding the tensorflow update. Once we land #3535, #3541, you should be able to check out the latest master and use it with the new libtpu-nightly.
@yeounoh Thanks! I'll check it out when these PRs are merged.
Update: following up on this -- I tried out the nightly 20220430 wheels for torch_xla, torch, and torchvision that contain #3535, #3541, and tested the speed again on a v3-128 pod. It seems that the speed regression issue persists in this nightly 20220430 version. Actually, the speed regression in the case above seems worse in the nightly 20220430 torch_xla wheel (used along with libtpu-nightly==0.1.dev20220413) compared to the previous nightly 20220308 wheels (when this issue was originally raised), as detailed below. I'm running it as follows using a v3-128 TPU VM:
I'm getting the following metrics at the end (after 20 epochs):
So it shows that the 50% percentile increased from 350ms108.598us (in tpu-vm-pt-1.10) to an even higher value with this nightly. @yeounoh Could you take a further look at this speed issue? Thanks a lot!
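For reference, the ExecuteTime percentiles and rates quoted throughout this thread come from torch_xla's debug metrics report; below is a minimal sketch of how such a report is printed (an assumed helper, not the exact logging code in main_pretrain.py).

# Minimal sketch (assumed helper, not the exact logging code in main_pretrain.py)
# of how the ExecuteTime / CompileTime percentiles quoted in this thread are read.
import torch_xla.debug.metrics as met

def log_xla_metrics(epoch):
    # metrics_report() returns a text report with counters and percentile
    # timings (1%..99%) for metrics such as ExecuteTime and CompileTime.
    print(f"===== XLA metrics after epoch {epoch} =====")
    print(met.metrics_report())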
Hi @ronghanghu, sorry to hear that, I will take a look. Just in case you've already tried: does this appear on v3-8 as well, like the previous regression?
@yeounoh Thanks again! I haven't retried on v3-8 yet (our real use case was on TPU pods). I can probably take another look at it this evening.
Update: the v3-8 speed also got much slower with nightly 20220430.
The 50% percentile of ExecuteTime is much higher here as well. It's measured as follows (running for 10 epochs):
A minor update: in both cases, the 50% percentile of ExecuteTime remained much slower than the 1.10 baseline.
Hi @ronghanghu, wanted to give a quick update. I was able to reproduce the (worsened) regression using the same setup.
Now, I took the torch_xla 1.11 release branch, applied a few manual patches, and ran again with the same libtpu-nightly-0.1.dev20220413:
The same libtpu (with the all_reduce path) runs just fine (this is the performance we observed with the older libtpu and torch_xla==1.10 before the regression). I think some of the recent changes (after the 1.11 release) are causing this more severe regression; we will continue to investigate and fix the nightly regression (you don't have to fall back to 1.11).
@yeounoh I see, thanks for the update!
Hi @ronghanghu, a quick update -- we were able to track down a bug on our side; this is the new nightly number after the patch:
There is still a +/- 20ms gap from the baseline (tpu-vm-pt-1.10), which we are still looking into.
Thanks, @yeounoh, this is good to know and the new numbers look much better. Looking forward to the fix!
That's great, @ronghanghu, thank you for confirming. I was able to use the 1.11 release branch, apply some patches, and run with the latest libtpu-nightly.
Sounds great, thanks @yeounoh!
Hi @yeounoh, this example above (on running MAE on TPUs) seemed to stop working on PyTorch/XLA 1.12 on TPU pods. It hangs on the first iteration and crashes after a while. Do you know if there is any recent change (particularly perhaps in libtpu) that could break things here? It hangs when I try running it from a v3-128 TPU VM (e.g. "debug-v3-128" below) on the 1.12 environment. On the other hand, it still works on a v3-8 under the same environment. To reproduce it:
TPU_NAME=debug-v3-128 # change to your TPU name
ZONE=europe-west4-a # change to your TPU zone
PROJECT_ID=fair-infra3f4ebfe6 # change to your GCP project id
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project ${PROJECT_ID} --worker all \
--command "
sudo pip3 install timm==0.4.12
# to resolve https://github.com/pytorch/xla/issues/3786
sudo pip3 uninstall -y tensorflow && sudo pip3 install tensorflow-cpu==2.9.1 # 2.9.1 is the latest version of "tensorflow-cpu"
"
MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-128_libtpu_pt120
MODEL=mae_vit_large_patch16
EPOCH=20
BATCH_SIZE_PER_TPU=32
TPU_NAME=debug-v3-128 # change to your TPU name
sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
cd ${HOME} && python3 -m torch_xla.distributed.xla_dist \
--tpu=${TPU_NAME} --restart-tpuvm-pod-server \
--env XRT_MESH_CONNECT_WAIT=1200 --env PYTHONUNBUFFERED=1 -- \
python3 ${MAE_PATH}/main_pretrain.py \
--output_dir ${SAVE_DIR} \
--log_dir ${SAVE_DIR} \
--batch_size ${BATCH_SIZE_PER_TPU} \
--model ${MODEL} \
--norm_pix_loss \
--mask_ratio 0.75 \
--epochs ${EPOCH} \
--warmup_epochs 40 \
--blr 0.0 --weight_decay 0.05 \
--fake_data \
--num_workers 8 \
--use_xla \
2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

This hangs on the 1st iteration on a v3-128 pod and eventually prints an error after a long timeout. On a v3-8, I use the same setup:
sudo pip3 install timm==0.4.12
# to resolve https://github.com/pytorch/xla/issues/3786
sudo pip3 uninstall -y tensorflow && sudo pip3 install tensorflow-cpu==2.9.1 # 2.9.1 is the latest version of "tensorflow-cpu"

and

MAE_PATH=/checkpoint/ronghanghu/workspace/mae_tpu # where the code is downloaded above
SAVE_DIR=~/vitl_mae_debug_fakedata_v3-8_libtpu_pt120
MODEL=mae_vit_large_patch16
EPOCH=4
BATCH_SIZE_PER_TPU=32
sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR
python3 ${MAE_PATH}/main_pretrain.py \
--output_dir ${SAVE_DIR} \
--log_dir ${SAVE_DIR} \
--batch_size ${BATCH_SIZE_PER_TPU} \
--model ${MODEL} \
--norm_pix_loss \
--mask_ratio 0.75 \
--epochs ${EPOCH} \
--warmup_epochs 40 \
--blr 0.0 --weight_decay 0.05 \
--fake_data \
--num_workers 8 \
--use_xla \
2>&1 | tee $SAVE_DIR/stdout_stderr_$(date +%Y-%m-%d_%H-%M-%S).log

which runs well and prints the training logs as expected.
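Since the pod run hangs while the single-host run is fine, one way to narrow things down (a hypothetical debugging sketch, not something from this thread) is a bare all_reduce smoke test across the TPU cores, independent of the MAE model:

# Hypothetical smoke test (not part of the original report): run a bare
# all_reduce across all local TPU cores to check whether the collective
# itself hangs, independent of the MAE model.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    t = torch.ones(1, device=device)
    xm.all_reduce(xm.REDUCE_SUM, [t])   # in-place all_reduce over the list of tensors
    xm.mark_step()                      # force execution of the pending graph
    print(f"ordinal {xm.get_ordinal()}: value={t.item()}, world size={xm.xrt_world_size()}")

if __name__ == "__main__":
    # 8 processes per v3-8 host; on a pod this would be launched on every
    # worker via torch_xla.distributed.xla_dist as in the commands above.
    xmp.spawn(_mp_fn, nprocs=8)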
@ronghanghu Can you share the XRT server log?
@JackCaoG I uploaded the XRT server log (on TPU VM worker 0 for the case of v3-128 pod) to https://gist.github.com/ronghanghu/36d0c29085d517e181e6edd460c6dfd8#file-server_20220817-215525-log. I guess this new issue is most likely because of some changes in libtpu. (And this gist above also has the stdout and stderr output files from the training command on v3-128 above.)
It seems like this execution hangs for half an hour and then just times out. Let's wait for #3899 to merge, then we can try a newer PT/XLA wheel and libtpu. We are also trying to bump it again to the August libtpu this week. @wonjoolee95 FYI. If we are still hitting the same issue with the latest libtpu, I will file a bug and get someone from the XLA team to help.
@JackCaoG Following up on this: with the newer wheels and libtpu, the hang above no longer happens. However, there is quite a serious speed regression (3X+ slower) for this example using PT/XLA nightly 20220822 together with the newer libtpu. Installing the new nightly environment as follows:
and running the profiling as described in this original issue, we get the results below.

On v3-128 pod

The v3-128 pod is the practical use case and suffers the most from this issue. Below are the metrics (running with fake data).
Here the 50% percentile increased from 350ms108.598us to 552ms560.581us and the rate dropped from 2.43071 / second to 0.659334 / second (3.7X slower).

On v3-8

The v3-8 pod also suffers from a similar drop. Below are the metrics (running with fake data) under the same per-TPU-core batch size as in v3-128.
Here the 50% percentile increased from 348ms614.947us to 01s084ms259.627us and the rate dropped from 2.57446 / second to 0.822482 / second (3.1X slower). It is also odd that the 50% percentile of ExecuteTime is much slower on v3-8 in this case (despite the same per-TPU batch size), although the final rate is still slightly higher on v3-8. Such a 3X+ slowdown indicates something is wrong with the current system. We plan to include this example in our profiling benchmarks (which we will set up in both JAX and PT/XLA) and see if we can find a clue as to why there is such a speed drop from PT/XLA 1.10 to the current version.
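For reference, the 3.7X and 3.1X factors above follow directly from the reported rates:

# Sanity check of the slowdown factors quoted above, computed from the
# per-second rates in the metrics reports (numbers copied from this comment).
rates = {
    "v3-128": (2.43071, 0.659334),   # (PT/XLA 1.10, nightly 20220822) iterations/second
    "v3-8":   (2.57446, 0.822482),
}
for setup, (old, new) in rates.items():
    print(f"{setup}: {old / new:.1f}X slower")   # ~3.7X and ~3.1X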
Thanks @ronghanghu. Can you open a new GitHub issue? We are working on the libtpu update again and hopefully can get it out by the end of this week. We will see if that solves the issue you are seeing. If not, we will work with the XLA team to fix the regression. @wonjoolee95 FYI.
Closing out the previous regression discussion (the ~6% speed regression in the May nightly): we've landed pytorch/pytorch#84503 in PyTorch to address the issue, and the patched May nightly was able to match the performance of 1.10. We should be able to focus on the changes since May for the current regression investigation. cc @ronghanghu, if you are opening up a new issue, you can close this one.
Thanks @yeounoh for the investigation and the fix! As for the recent issue we saw above in #3441 (comment), we will wait until the new TensorFlow pin (#3922) lands, test again, and submit a new issue if we still see performance problems on this example.
🐛 Bug
There is a notable speed regression in a few models with the nightly 20220308 PyTorch/XLA builds. For example, in MAE ViT-L (code attached below), the training time increased from 0.46 s/iter to 0.61 s/iter on a v3-128 pod after switching from tpu-vm-pt-1.10 to the nightly 20220308 builds of torch, torchvision, torch_xla, and libtpu.
The speed issue is particularly related to libtpu: the speed regression above can be reproduced by installing libtpu_nightly-0.1.dev20220308 on the tpu-vm-pt-1.10 runtime environment while keeping torch, torchvision, and torch_xla as-is from tpu-vm-pt-1.10. (As long as libtpu_nightly-0.1.dev20220308 is used, the speed issue happens regardless of whether torch, torchvision, and torch_xla are the tpu-vm-pt-1.10 versions or the nightly 20220308 versions.)

Increased ExecuteTime metric in libtpu nightly 20220308: when running the very same codebase against two libtpu versions (nightly 20220308 vs tpu-vm-pt-1.10), the ExecuteTime increased dramatically. This speed issue happens on both a v3-128 TPU pod and a v3-8.

Code and the dumped HLO graphs are attached below to reproduce this speed drop.
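Since the comparison hinges entirely on which libtpu wheel is loaded, a quick programmatic check of the installed version can help (an illustrative helper added here, not part of the original report; the distribution name is assumed from the wheel file name):

# Illustrative helper (not part of the original report): confirm which libtpu
# nightly is installed before benchmarking. The distribution name is assumed
# from the wheel file name (libtpu_nightly-0.1.devYYYYMMDD-...).
from importlib import metadata

for name in ("libtpu-nightly", "libtpu_nightly"):
    try:
        print(f"{name}: {metadata.version(name)}")
        break
    except metadata.PackageNotFoundError:
        continue
else:
    print("no libtpu nightly wheel found")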
On v3-128 pod
The v3-128 pod is the practical use case and suffers the most from this issue. Below are the metrics (running with fake data).
The 50% percentile increased from 350ms108.598us to 469ms870.575us (a major speed drop) and the rate dropped from 2.43071 / second to 1.89054 / second.
On v3-8
The v3-8 pod also suffers from a similar drop. Below are the metrics (running with fake data) under the same per-TPU-core batch size as in v3-128.
Here the 50% percentile increased from 348ms614.947us to 388ms388.937us (a smaller speed drop than on v3-128 but still quite notable).
The dumped HLO graph for the v3-8 case (using XLA_SAVE_TENSORS_FILE=$hlo_$(date +%Y-%m-%d_%H-%M-%S).txt XLA_SAVE_TENSORS_FMT=hlo; run only for 5 iterations to avoid saving too many outputs), as well as the stdout and stderr outputs containing the printed metrics above, are uploaded to Google Drive.
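For anyone reproducing the dump, the two environment variables can also be set from Python; a minimal sketch follows (the file name here is just a placeholder, and setting them before torch_xla is imported is assumed to be the safe order):

# Minimal sketch of enabling the HLO dump described above. The file name is a
# placeholder; set these before torch_xla is imported so they take effect.
import os

os.environ["XLA_SAVE_TENSORS_FILE"] = "hlo_dump.txt"  # placeholder path
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"            # dump graphs in HLO format

import torch_xla.core.xla_model as xm  # import after the env vars are set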
To Reproduce

1. Set up the TPU VM (v3-128 pod or v3-8) with the tpu-vm-pt-1.10 environment.
2. sudo pip3 install timm==0.4.12 on all VM nodes (e.g. through gcloud alpha compute tpus tpu-vm ssh --worker all).
3. Download the MAE code (to /checkpoint/ronghanghu/workspace/mae_tpu).
4. For the tpu-vm-pt-1.10 baseline, don't do any change; for the nightly libtpu case, run sudo pip3 uninstall -y libtpu_nightly && sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220308-py3-none-any.whl on all VM nodes.

On v3-128 pod
To run on a v3-128 pod (print metrics and save stdout and stderr; running with fake data):
The code will print the XLA metrics such as ExecuteTime after each training epoch.

On v3-8
To run on a v3-8 (print metrics and save stdout and stderr; running with fake data):
The code will print the XLA metrics such as ExecuteTime after each training epoch.

Expected behavior
The speed regression should not happen when switching to the nightly 20220308 version of libtpu.
Environment
tpu-vm-pt-1.10 (and install the nightly 20220308 version of libtpu)
The larger speed drop on a v3-128 pod suggests that there could be significant overhead when running the same model on a pod vs on a v3-8 under the same per-TPU batch size. This is somewhat unexpected: the model itself does not have any cross-TPU communication ops such as all_gather in it, and the only op involving cross-TPU communication is the gradient reduction before applying optimizer updates. So I believe it would be great to investigate and resolve this overhead, as it severely hurts many practical use cases.
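For context, the gradient reduction mentioned here is the all_reduce that PyTorch/XLA performs around the optimizer step. A minimal sketch of a typical training step follows (assuming a standard xm.optimizer_step-based loop, not the exact MAE code; loss_fn is a placeholder):

# Minimal sketch of a typical PyTorch/XLA training step (not the exact MAE
# training loop). The only cross-TPU communication is the gradient all_reduce
# inside xm.optimizer_step, which runs before the optimizer update is applied.
import torch_xla.core.xla_model as xm

def train_step(model, images, optimizer, loss_fn):
    optimizer.zero_grad()
    loss = loss_fn(model, images)   # placeholder for the model's forward + loss
    loss.backward()
    xm.optimizer_step(optimizer)    # all-reduces gradients across cores, then steps
    return loss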
cc @JackCaoG