PyTorch GPU XLA 15% slower than Vanilla PyTorch (Imagenet Classification w/ R-50, Densenet etc.) #2901
Comments
XLA shines when AMP is enabled. You should try enabling fp16 for your model.
Thank you, it is indeed faster now by 27% 🙏. Can you tell me (or link me to) why AMP is such a big factor for improvement in XLA?
XLA does kernel fusion to reduce the memory-access footprint. Float16 computation utilizes tensor cores: these ops are so fast that memory bandwidth becomes the bottleneck.
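For anyone following along, a minimal sketch of what "enable fp16/AMP" looks like on the vanilla PyTorch baseline side, using the standard torch.cuda.amp API; the resnet50 model and SGD optimizer are placeholders, not the training script used in this issue:

```python
import torch
import torchvision

# Placeholder model/optimizer; the point is the autocast + GradScaler pattern.
model = torchvision.models.resnet50().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()  # loss scaling to avoid fp16 gradient underflow

def train_step(images, targets):
    optimizer.zero_grad()
    # Eligible ops (convs, matmuls) run in fp16 and hit the tensor cores.
    with torch.cuda.amp.autocast():
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```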
Thanks 😊
Reopening this: it was faster on Tesla T4s, however the exact same model with the same parameters and CUDA versions is slower by 16% on V100, Densenet is even slower (by 30%), and overall the numbers look pretty bad for XLA on V100. Is this expected?
I suspect it's an environment issue. Mind posting your env here? If we see raw performance numbers it will be easier to help.
Sure,
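A quick way to capture the environment details being asked for, assuming torch and torch_xla are importable inside the container (this is a generic snapshot, not the exact output posted in the thread):

```python
# Minimal environment snapshot for comparing the T4 and V100 runs.
import torch
import torch_xla

print("torch:", torch.__version__)
print("cuda (build):", torch.version.cuda)
print("gpu:", torch.cuda.get_device_name(0))
print("torch_xla:", torch_xla.__version__)
# For a fuller report: python -m torch.utils.collect_env
```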
Do you have numbers for T4? It should be about 1/3 of V100 w/ AMP and XLA.
Yes, it's 305 w/ XLA and 240 for vanilla PT.
I do see different CUDA versions for T4 and V100. XLA embedded in TF never shipped a 10.2 version... mind using 10.1 for XLA?
The default container at gcr.io/tpu-pytorch/xla:nightly_3.6_cuda supports CUDA 10.2, therefore I decided to use 10.2 for the corresponding PT comparison as well. For the T4 comparison I was actually using an older container which had CUDA 10.1.
Could you try using the same 10.1 container on V100? That would certainly help.
Just tried it using 10.1. The performance numbers are the same for both p3.2xlarge and p3.16xlarge. No difference at all (i.e., < 1% difference, with CUDA 10.1 being the slower one across ResNet and DenseNet).
I've been working on XLA support for some of my code and ended up trying XLA + GPU + AMP on the nightly docker containers. I've noticed some gap in performance too; it varies quite a bit based on model arch. For classic ResNet I only see approximately 8% lower throughput on Titan RTX with PyTorch XLA GPU + AMP on the 3.7 nightly docker vs a comparable PyTorch CUDA 10.2 conda env. I'm seeing a bigger gap w/ larger batch sizes in distributed GPU training, but that could be a bug in my code...
I suspect tensorflow/tensorflow#44985 will help; will try to get it merged and update TF for PT/XLA.
@JackCaoG Ok, that seems to be a large improvement of almost 76%. The PR seems to have been opened last year, so let me know if there is anything I could help with to speed up the process.
@codeislife99 These are all on Titan RTX w/ AMP enabled, in both the pytorch-3.7-cuda nightly docker from a day ago and PyTorch 1.7.1 CUDA 10.2 (no different from 1.8.x 10.2 builds when tested on the same model before). There are significant gains running with optimized cuDNN kernels in the latest NGC builds though... but I haven't tried building a custom PyTorch XLA to see if SM 8.6 works for my 3090 (different machine) or if there is better NHWC support for Titan RTX (it works in JAX when I build jaxlib custom in NGC docker).
Single GPU, batch_size 128
2x GPU distributed training
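On the NHWC point, this is the standard channels_last opt-in on the plain PyTorch side; whether the XLA path lowers it to NHWC cuDNN kernels is exactly the open question above, so treat this as a baseline-only sketch with placeholder shapes:

```python
import torch
import torchvision

# Channels-last (NHWC) tensors let cuDNN pick tensor-core-friendly kernels under AMP.
model = torchvision.models.resnet50().cuda().to(memory_format=torch.channels_last)
images = torch.randn(128, 3, 224, 224, device="cuda").to(memory_format=torch.channels_last)

with torch.cuda.amp.autocast():
    out = model(images)
print(out.shape)  # torch.Size([128, 1000])
```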
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🐛 Bug
PyTorch GPU XLA seems to be 15% slower than Native PyTorch with the official CUDA XLA containers at gcr.io.
To Reproduce
Steps to reproduce the behavior:
1. Run GPU_NUM_DEVICES=1 python /pytorch/xla/test/test_train_mp_imagenet.py inside the publicly available container.
2. Change the device to cuda in the script from (1) and the optimizer to a torch optimizer to run the non-XLA (vanilla PT) version of training.
3. Change the model to densenet121 and repeat (1) and (2).

Both models are 15-18% slower in XLA GPU than in vanilla PyTorch CUDA.
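A hedged sketch of the device/optimizer swap described in steps (1)-(2); this is not the actual test_train_mp_imagenet.py code, and the resnet50 model and random batch are stand-ins:

```python
import torch
import torchvision
import torch_xla.core.xla_model as xm

USE_XLA = True  # set False for the vanilla PyTorch CUDA baseline

device = xm.xla_device() if USE_XLA else torch.device("cuda")
model = torchvision.models.resnet50().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

images = torch.randn(128, 3, 224, 224, device=device)
targets = torch.randint(0, 1000, (128,), device=device)

optimizer.zero_grad()
loss = criterion(model(images), targets)
loss.backward()
if USE_XLA:
    # barrier=True forces the lazily-built XLA graph to compile and execute.
    xm.optimizer_step(optimizer, barrier=True)
else:
    optimizer.step()
```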
Expected behavior
Expected vanilla PyTorch to be slower than, or at most equivalent to, PyTorch XLA GPU performance for simple static models like resnet/densenet.
Environment