
Add DDP to WMT + faster ImageNet data loading #85

Merged: 14 commits, Jun 29, 2022

Conversation

@runame (Contributor) commented Jun 21, 2022

Follow-up PR to #81.

This PR

  • adds DDP to the WMT PyTorch workload,
  • fixes a bug in the WMT Jax workload,
  • improves data loading for the ImageNet PyTorch workload,
  • adds instructions for running DDP to the README.

One question specifically for @mikerabbat: I have implemented the distributed sampling for the TF dataset by simply sharding each global batch manually across devices; see this wrapper, which assumes this function has been mapped over the dataset beforehand. I think you mentioned at some point that there is an issue with this approach, but I'm not sure what it was.
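
For illustration, the idea is roughly the following (a sketch only, not the exact wrapper linked above; shard_batch and the dict-of-NumPy-arrays batch layout are placeholders):

import torch
import torch.distributed as dist

def shard_batch(global_batch, per_device_batch_size):
  """Return this process' slice of a global batch of NumPy arrays."""
  rank = dist.get_rank()
  start = rank * per_device_batch_size
  end = start + per_device_batch_size
  return {k: torch.as_tensor(v[start:end]) for k, v in global_batch.items()}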

And one question @znado might be able to help with: when running the DDP WMT PyTorch workload with 8xV100s instead of 4xV100s, I'm getting this error:

Check failed: ret == 0 (11 vs. 0)Thread tf_pjrt_thread_pool creation via pthread_create() failed.

I think this is caused by the fact that the TF input pipeline is created separately for each process, i.e. as many times as there are GPUs. The error seems to indicate that this exhausts the available thread resources, see this explanation. I tried setting AUTOTUNE = None (here) and increasing the number of threads allowed per process (using torch.set_num_threads(N) with N up to 8; the default is N=1), but neither helped.
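
One knob that might still be worth trying (a sketch only, not part of this PR; these option names exist in TF >= 2.5, older versions expose them under experimental_threading) is capping the threads each per-process tf.data pipeline may spawn:

import tensorflow as tf

def limit_tf_data_threads(ds, num_threads=1):
  # Cap the pipeline's private threadpool and its intra-op parallelism.
  options = tf.data.Options()
  options.threading.private_threadpool_size = num_threads
  options.threading.max_intra_op_parallelism = 1
  return ds.with_options(options)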

@github-actions bot commented Jun 21, 2022

MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅

@znado (Contributor) commented Jun 22, 2022

I would have hoped that setting AUTOTUNE = None would have fixed that out-of-threads issue. Beyond the number of threads allowed per process, do you know if the host machines you're running on have a low number of max threads allowed (I could see that being the case on some clusters to avoid accidental fork bombs)? I'm not sure of the most relevant place to check, but I know there are some places you can check at the Linux/system level (you may already have, just checking!).
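
For example, one quick diagnostic on Linux (purely illustrative, not something for the PR itself):

import resource

# Per-user cap on processes/threads; hitting it is one way pthread_create
# fails with EAGAIN (errno 11).
soft, hard = resource.getrlimit(resource.RLIMIT_NPROC)
print(f"RLIMIT_NPROC: soft={soft}, hard={hard}")

# System-wide ceiling on the total number of threads.
with open("/proc/sys/kernel/threads-max") as f:
  print("kernel.threads-max:", f.read().strip())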

As for sharding the global batch across GPUs like you do, would it instead be more efficient to run each copy of the input pipeline with the per-GPU batch size? Then you don't need to slice the batch (like you do here); you would just need to add a toggleable flag to the input pipeline to not shard the batches. Maybe this would help avoid memory issues (although it looks like the resource issue you're having is with threading, not memory?). To make sure you get a unique batch on each process, you can (hopefully?) fold the process rank into the seed; for example, you could add this here:

data_rng = jax.random.fold_in(data_rng, torch.distributed.get_rank())
np_iter = super().build_input_queue(data_rng, ...)

@znado (Contributor) left a comment

just one question, otherwise this is fantastic, thanks so much for all these fixes!! it's awesome to get the pytorch pipelines sped up

runame added 2 commits June 23, 2022 14:20
Conflicts:
	reference_submissions/imagenet_vit/imagenet_pytorch/submission.py
@runame (Contributor, Author) commented Jun 28, 2022

> I would have hoped that setting AUTOTUNE = None would have fixed that out-of-threads issue. Beyond the number of threads allowed per process, do you know if the host machines you're running on have a low number of max threads allowed (I could see that being the case on some clusters to avoid accidental fork bombs)? I'm not sure of the most relevant place to check, but I know there are some places you can check at the Linux/system level (you may already have, just checking!).
>
> As for sharding the global batch across GPUs like you do, would it instead be more efficient to run each copy of the input pipeline with the per-GPU batch size? Then you don't need to slice the batch (like you do here); you would just need to add a toggleable flag to the input pipeline to not shard the batches. Maybe this would help avoid memory issues (although it looks like the resource issue you're having is with threading, not memory?). To make sure you get a unique batch on each process, you can (hopefully?) fold the process rank into the seed; for example, you could add this here:

I haven't looked much further into the threads issue, but will try again later. PyTorch limits the threads per process to 1 by default when using DDP (OMP_NUM_THREADS=1), but increasing this didn't help; I'll double-check that there is no limit on the hardware/system level. In any case, I think it's fine to merge this PR before this is fixed.

Regarding a more efficient sharding strategy: I also briefly thought about this, but unless I'm missing something, your suggested approach could produce global batches that contain the same example multiple times: even though the local batches won't be exactly the same due to the different random seeds, all processes still sample from the full dataset. If instead each process only sees a subset of the dataset, the shuffling becomes biased, since a global batch can never contain more than per_device_batch_size samples from any one subset. That's why I chose the simple but inefficient approach.

@znado (Contributor) commented Jun 29, 2022

Yeah, that's a good point regarding possibly repeated examples. A solution we've used for that is sharding the input files across processes, so you're guaranteed different examples per process (and then you don't need to mess with the RNGs).
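
Roughly along these lines (a sketch only, assuming the data lives in many TFRecord shards; the pattern, reader, and helper name are placeholders, not code from this repo):

import tensorflow as tf
import torch.distributed as dist

def build_sharded_dataset(file_pattern):
  # List files deterministically so every process sees the same ordering,
  # then give each DDP process a disjoint subset of the files.
  files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
  files = files.shard(num_shards=dist.get_world_size(), index=dist.get_rank())
  return files.interleave(
      tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)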

@znado merged commit ffdfc08 into mlcommons:main on Jun 29, 2022.
The github-actions bot locked the conversation and limited it to collaborators on Jun 29, 2022.