
PyTorch GPU XLA 15% slower than Vanilla PyTorch (Imagenet Classification w/ R-50, Densenet etc.) #2901


Closed
codeislife99 opened this issue Apr 20, 2021 · 18 comments
Labels
stale Has not had recent activity xla:gpu

Comments

@codeislife99

codeislife99 commented Apr 20, 2021

🐛 Bug

PyTorch XLA on GPU appears to be about 15% slower than native PyTorch when using the official CUDA XLA containers from gcr.io.

To Reproduce

Steps to reproduce the behavior:

  1. Run `GPU_NUM_DEVICES=1 python /pytorch/xla/test/test_train_mp_imagenet.py` inside the publicly available container.
  2. Change the device in the script from (1) to `cuda` and the optimizer to a stock torch optimizer to run the non-XLA (vanilla PyTorch) version of the training.
  3. Change the model to `densenet121` and repeat (1) and (2).

Both models are 15-18% slower on XLA GPU than on vanilla PyTorch CUDA.
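For reference, a minimal sketch of the device/optimizer switch described in step 2 (the `USE_XLA` toggle and the `train_step` wrapper are hypothetical; the actual test script is structured differently):

```python
import torch
import torchvision

USE_XLA = True  # hypothetical toggle: set to False for the vanilla PyTorch/CUDA baseline

if USE_XLA:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()        # XLA GPU device when GPU_NUM_DEVICES is set
else:
    device = torch.device("cuda")   # native PyTorch CUDA device

model = torchvision.models.resnet50().to(device)  # step 3: swap in densenet121 here
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

def train_step(images, labels):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(images.to(device)), labels.to(device))
    loss.backward()
    if USE_XLA:
        xm.optimizer_step(optimizer, barrier=True)  # steps and marks the lazy XLA graph for execution
    else:
        optimizer.step()
    return loss
```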

Expected behavior

Expected vanilla PyTorch to be slower than, or at most equivalent to, PyTorch XLA GPU performance for simple static models like ResNet/DenseNet.

Environment

  • Tesla T4, 1 GPU, CUDA 10.1
  • torch_xla version: 1.9+3925f6e
@byronyi
Contributor

byronyi commented Apr 20, 2021

XLA shines when AMP is enabled. You should try enabling fp16 for your model.
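On the vanilla PyTorch side, enabling AMP is the standard `torch.cuda.amp` pattern; a minimal sketch of that baseline (the XLA script has its own AMP path, which is not shown here):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, optimizer, images, labels):
    optimizer.zero_grad()
    # Run the forward pass in mixed precision so conv/matmul kernels can use tensor cores.
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(images), labels)
    # Scale the loss to avoid fp16 gradient underflow; step() unscales before updating.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss
```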

@codeislife99
Author

Thank you, it is indeed faster now, by 27% 🙏. Can you tell me (or point me to) why AMP is such a big factor for improvement in XLA?

@byronyi
Contributor

byronyi commented Apr 20, 2021

XLA does kernel fusion to reduce the memory-access footprint. Float16 computation uses the tensor cores: those ops are so fast that memory bandwidth becomes the bottleneck.
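As a toy illustration (not from the thread): a chain of elementwise ops like the one below costs several kernel launches and global-memory round trips in eager mode, while XLA can fuse it into a single kernel that reads the input once and writes the result once. The cheaper the arithmetic gets (e.g. fp16 on tensor cores), the more that saved memory traffic matters.

```python
import torch

def fusion_candidate(x, bias):
    # Eager PyTorch launches one kernel per line here and materializes each
    # intermediate in GPU memory; an XLA-compiled version can fuse the chain.
    y = x + bias
    y = torch.relu(y)
    y = y * 0.5
    return y.sum()
```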

@codeislife99
Author

Thanks 😊

@codeislife99
Author

codeislife99 commented Apr 21, 2021

Reopening this: it was faster on Tesla T4s, but the exact same model with the same parameters and CUDA versions is 16% slower on V100, DenseNet is even slower (by 30%), and overall the numbers look pretty bad for XLA on V100. Is this expected?

@byronyi
Contributor

byronyi commented Apr 21, 2021

I suspect that it's an environment issue. Mind posting your env here? Raw performance numbers would make it easier to help.

@codeislife99
Author

codeislife99 commented Apr 21, 2021

Sure.
Env: Tesla V100, CUDA 10.2, batch_size 128, AWS p3.2xlarge, Ubuntu 18.04 Deep Learning AMI, script here. The following table summarizes my performance evaluation:

| Model | # GPUs | Instance | Batch Size | Other relevant variables | w/ XLA AMP Throughput | Vanilla PT AMP Throughput | Speedup | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Resnet-50 | 1 | p3.2xlarge | 128 | syn data, CUDA 10.2 | 634 | 750 | 0.84533 | P3s seem to be slower than G4s |
| Resnet-50 | 1 | p3.16xlarge | 128 | syn data, CUDA 10.2 | 570 | 750 | 0.76 | On a larger P3 the performance loss seems to be higher |
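For context, the throughput numbers above are images/sec; a minimal way to measure that on synthetic data for the CUDA baseline looks roughly like this (hypothetical helper, not the script's actual logic; on the XLA side the synchronization point would be the XLA step marker rather than `torch.cuda.synchronize()`):

```python
import time
import torch

def measure_throughput(train_step, batch_size=128, steps=100, warmup=10):
    # Synthetic ImageNet-shaped batch, reused so only compute and launch overhead are timed.
    images = torch.randn(batch_size, 3, 224, 224, device="cuda")
    labels = torch.randint(0, 1000, (batch_size,), device="cuda")
    for _ in range(warmup):
        train_step(images, labels)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        train_step(images, labels)
    torch.cuda.synchronize()
    return steps * batch_size / (time.time() - start)  # images per second
```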

@byronyi
Contributor

byronyi commented Apr 21, 2021

Do you have numbers for T4? It should be about 1/3 of V100 w/ AMP and XLA.

@codeislife99
Author

Yes, it's 305 img/s with XLA and 240 for vanilla PT.

@byronyi
Contributor

byronyi commented Apr 21, 2021

I do see different CUDA versions for the T4 and V100 runs. The XLA embedded in TF never shipped a 10.2 version. Mind using 10.1 for XLA?

@codeislife99
Author

codeislife99 commented Apr 21, 2021

The default container at gcr.io/tpu-pytorch/xla:nightly_3.6_cuda supports CUDA 10.2, so I decided to use 10.2 for the corresponding PT comparison as well. For the T4 comparison I was actually using an older container which had CUDA 10.1.
Another thing to note: both the CUDA 10.1 and 10.2 XLA containers had the same performance on the Tesla T4, so I think we can probably eliminate the CUDA version as the reason.
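To rule out a container/toolkit mismatch, a quick sanity check of what each environment actually runs (assuming recent builds expose `torch_xla.__version__`):

```python
import torch

print("torch:", torch.__version__)
print("CUDA toolkit torch was built against:", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())

try:
    import torch_xla
    print("torch_xla:", torch_xla.__version__)  # assumed attribute on recent builds
except ImportError:
    print("torch_xla not installed in this environment")
```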

@byronyi
Contributor

byronyi commented Apr 21, 2021

Could you try using the same 10.1 container on the V100? That would certainly help.

@codeislife99
Author

codeislife99 commented Apr 21, 2021

Just tried it with 10.1. The performance numbers are the same for both p3.2xlarge and p3.16xlarge, no difference at all (i.e. < 1% difference, with CUDA 10.1 being the slightly slower one across ResNet and DenseNet).

@rwightman

I've been working on XLA support for some of my code and ended up trying XLA + GPU + AMP on the nightly docker containers. I've noticed a performance gap too; it varies quite a bit based on model arch. For classic ResNet I only see roughly 8% lower throughput on Titan RTX with PyTorch XLA GPU + AMP on the 3.7 nightly docker vs a comparable PyTorch CUDA 10.2 conda env.

I'm seeing a bigger gap w/ larger batch sizes in distributed GPU training but that could be a bug in my code...
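For reference, the multi-GPU XLA path being compared here is (roughly) the `xmp.spawn` pattern; a minimal sketch with synthetic data, not the actual benchmark code:

```python
import torch
import torchvision
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()  # one XLA device (one GPU) per spawned process
    model = torchvision.models.resnet50().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # Synthetic batch, just to exercise the training step.
    images = torch.randn(128, 3, 224, 224).to(device)
    labels = torch.randint(0, 1000, (128,)).to(device)
    for _ in range(10):
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(images), labels)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduces gradients across processes and marks the step

if __name__ == "__main__":
    # e.g. GPU_NUM_DEVICES=2 python this_script.py
    xmp.spawn(_mp_fn, nprocs=2)
```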

@JackCaoG
Collaborator

I suspect tensorflow/tensorflow#44985 will help; I'll try to get it merged and update TF for PT/XLA.

@codeislife99
Author

@JackCaoG OK, that seems to be a large improvement of almost 76%. The PR seems to have been opened last year, so let me know if there is anything I can do to help speed up the process.
@rwightman I have made the same observations as you on all fronts (model arch, batch size, and multi-GPU setting). Would you mind posting your benchmark numbers here for the community, or any other concrete observations you have made? Thanks :)

@rwightman

rwightman commented Apr 21, 2021

@codeislife99 these are all on Titan RTX with AMP enabled, in both the pytorch-3.7-cuda nightly docker from a day ago and PyTorch 1.7.1 with CUDA 10.2 (no different from the 1.8.x CUDA 10.2 builds when tested on the same model before). There are significant gains running with the optimized cuDNN kernels in the latest NGC builds, though... but I haven't tried building a custom PyTorch XLA to see if SM 8.6 works for my 3090 (different machine) or whether there is better NHWC support for Titan RTX (it works in JAX when I build jaxlib custom in the NGC docker).

Single GPU, batch_size 128:
PyTorch XLA: ~530 img/s
PyTorch 1.7.1, CUDA 10.2: ~570 img/s
PyTorch NGC 21.03 (CUDA 11.2, cuDNN 8, channels_last): 700+ img/s

2x GPU distributed training:
PyTorch XLA: ~1000 img/s at batch 128 (per card), ~850 img/s at 256
PyTorch 1.7.1: > 1100 img/s
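For the NGC channels_last numbers above, the NHWC switch on the native PyTorch side is just the standard memory-format conversion; a sketch (not the benchmark code):

```python
import torch
import torchvision

model = torchvision.models.resnet50().cuda().to(memory_format=torch.channels_last)
images = torch.randn(128, 3, 224, 224, device="cuda").contiguous(memory_format=torch.channels_last)

with torch.cuda.amp.autocast():
    out = model(images)  # with fp16 + NHWC, cuDNN can pick tensor-core convolution kernels
```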

@stale

stale bot commented Jun 11, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the stale (Has not had recent activity) label on Jun 11, 2021
stale bot closed this as completed on Jun 22, 2021