Add DDP to WMT + faster ImageNet data loading #85
Conversation
MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅
I would have hoped that setting `AUTOTUNE = None` would have fixed this. As for sharding the global batch across GPUs like you do: would it instead be more efficient to run each copy of the input pipeline with the per-GPU batch size? Then you don't need to slice the batch (like you do here); you would just need to add a toggleable flag to the input pipeline to not shard the batches. Maybe this would help avoid memory issues (although it looks like the resource issue you're having is with threading, not memory?). To make sure you get a unique batch on each process, you can (hopefully?) fold the process rank into the seed, for example you could add this here:
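The inline snippet that originally followed this suggestion didn't survive the export; a minimal hypothetical sketch of the idea, assuming the input pipeline takes an integer seed and that `torch.distributed` has already been initialized by the DDP setup, might look like:

```python
# Hypothetical sketch (not the original snippet): fold the process rank into the
# data seed so each DDP process draws a different shuffle order from its own
# copy of the input pipeline.
import torch.distributed as dist


def fold_in_rank(seed: int) -> int:
  """Derive a per-process data seed from the global seed."""
  rank = dist.get_rank() if dist.is_initialized() else 0
  return seed + rank
```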
just one question, otherwise this is fantastic, thanks so much for all these fixes!! it's awesome to get the pytorch pipelines sped up
Conflicts: reference_submissions/imagenet_vit/imagenet_pytorch/submission.py
I haven't looked much further into the threads issue, but will try again later (PyTorch limits the number of threads per process). Regarding a more efficient sharding strategy: I also briefly thought about this, but unless I'm missing something, with your suggested approach there might be examples appearing multiple times in the same global batch, since even though the local batches won't be exactly the same due to the different random seeds, all processes still have access to the full dataset. When only passing a subset of the dataset to each process, the issue is that the shuffling will be biased, since examples can never be shuffled across the shards.
Yeah that's a good point regarding possibly repeated examples. A solution we've used for that is sharding the input files across processes, so you guarantee you have different examples per process (and then you don't need to mess with the RNGs).
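For context, a rough sketch of this file-level sharding approach, assuming a `tf.data` pipeline built from a list of TFRecord files (the names below are illustrative, not the repo's actual code):

```python
# Sketch of file-level sharding: each DDP process reads a disjoint subset of the
# input files, so examples are guaranteed to be unique per process without
# touching the shuffle seeds.
import tensorflow as tf


def build_sharded_dataset(file_pattern, rank, world_size, shuffle_seed):
  files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
  files = files.shard(num_shards=world_size, index=rank)  # disjoint files per process
  ds = files.interleave(
      tf.data.TFRecordDataset,
      cycle_length=4,
      num_parallel_calls=tf.data.AUTOTUNE)
  return ds.shuffle(buffer_size=10_000, seed=shuffle_seed)
```

Note that this only works cleanly when the number of input files is a multiple of (or at least much larger than) the number of processes, otherwise the per-process shards end up unbalanced.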
Follow-up PR to #81.
This PR adds DDP support to the WMT PyTorch workload and speeds up the PyTorch ImageNet data loading.
One question specifically for @mikerabbat: I have implemented the distributed sampling for the TF dataset by simply manually sharding each global batch across devices, see this wrapper, which assumes this function has been mapped over the dataset beforehand. I think you mentioned at some point that there is an issue with this approach, but I'm not sure what it was.
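For readers who can't follow the link, a minimal sketch of what such a wrapper does, assuming the global batch is a dict of NumPy arrays whose leading dimension is divisible by the number of processes (names here are illustrative, not the actual code):

```python
# Illustrative sketch of slicing a global batch into per-process shards by rank;
# not the actual wrapper referenced above.
import numpy as np


def shard_batch_for_rank(batch, rank, world_size):
  """Return this process's slice of a globally batched dict of arrays."""
  def slice_fn(x):
    x = np.asarray(x)
    per_device = x.shape[0] // world_size
    return x[rank * per_device:(rank + 1) * per_device]
  return {name: slice_fn(value) for name, value in batch.items()}
```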
And one question @znado might be able to help with: when running the DDP WMT PyTorch workload with 8xV100s instead of 4xV100s, I'm getting this error:
Check failed: ret == 0 (11 vs. 0) Thread tf_pjrt_thread_pool creation via pthread_create() failed.
I think this is caused by the fact that the TF input pipeline is created for each process separately, i.e. as many times as there are GPUs. The error seems to indicate that this exhausts the available resources, see this explanation. I tried setting `AUTOTUNE = None` (here) and increasing the number of threads allowed per process (using `torch.set_num_threads(N)` with `N` up to 8; the default is `N=1`), but neither helped.
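One more knob that might be worth trying (purely speculative, not something attempted in this PR) is capping the threads each per-process `tf.data` pipeline is allowed to create via `tf.data.Options`, so that several per-GPU copies of the pipeline don't exhaust the per-process thread limit:

```python
# Speculative mitigation: restrict the private threadpool of each tf.data
# pipeline so multiple per-GPU copies don't exhaust thread resources.
import tensorflow as tf


def limit_pipeline_threads(ds: tf.data.Dataset, num_threads: int = 2) -> tf.data.Dataset:
  options = tf.data.Options()
  options.threading.private_threadpool_size = num_threads
  options.threading.max_intra_op_parallelism = 1
  return ds.with_options(options)
```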