
Add DDP to WMT + faster ImageNet data loading #85

Merged (14 commits) on Jun 29, 2022
2 changes: 2 additions & 0 deletions .gitignore
@@ -7,3 +7,5 @@ env/
venv/
workdir/
makefile
*.out
*.sh
10 changes: 9 additions & 1 deletion README.md
@@ -136,7 +136,7 @@ Docker is the easiest way to enable PyTorch/JAX GPU support on Linux since only
```bash
python3 submission_runner.py \
--framework=jax \
--workload=mnist |\
--workload=mnist \
--submission_path=baselines/mnist/mnist_jax/submission.py \
--tuning_search_space=baselines/mnist/tuning_search_space.json
```
@@ -151,6 +151,14 @@ python3 submission_runner.py \
--tuning_search_space=baselines/mnist/tuning_search_space.json
```

When using multiple GPUs on a single node, it is recommended to use PyTorch's
[distributed data parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
To do so, replace `python3` with
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=N_GPUS
```
where `N_GPUS` is the number of available GPUs on the node.
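
For example, on a node with 8 GPUs, a full MNIST run might look as follows (the PyTorch submission path shown here is illustrative):

```bash
torchrun --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --submission_path=baselines/mnist/mnist_pytorch/submission.py \
    --tuning_search_space=baselines/mnist/tuning_search_space.json
```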

## Rules

The rules for the MLCommons Algorithmic Efficiency benchmark can be found in the separate [rules document](RULES.md). Suggestions, clarifications, and questions can be raised via pull requests.
108 changes: 103 additions & 5 deletions algorithmic_efficiency/data_utils.py
@@ -1,11 +1,14 @@
import jax
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler
from torch.utils.data import Sampler


def shard_numpy_ds(xs):
"""Prepare tf data for JAX
"""Prepare tf data for JAX or PyTorch DDP.

Convert an input batch from tf Tensors to numpy arrays and reshape it to be
sharded across devices.
@@ -16,8 +19,9 @@ def _prepare(x):
# Use _numpy() for zero-copy conversion between TF and NumPy.
x = x._numpy() # pylint: disable=protected-access

# reshape (host_batch_size, height, width, 3) to
# (local_devices, device_batch_size, height, width, 3)
# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1) + x.shape[1:])

return jax.tree_map(_prepare, xs)
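
# Example (illustrative): with 8 local devices and a global batch of 128
# images of shape (128, 224, 224, 3), this returns arrays of shape
# (8, 16, 224, 224, 3), i.e. one leading slice per device/process.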
@@ -33,7 +37,7 @@ def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False):
assert len(keys) == len(batch)
yield dict(zip(keys, batch))
except StopIteration:
if custom_sampler:
if custom_sampler and isinstance(iterable, DataLoader):
epoch += 1
iterable.sampler.set_epoch(epoch)
iterator = iter(iterable)
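
# Example (illustrative): the PyTorch workloads below wrap their DataLoaders as
# `cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)` so that, under DDP, the
# DistributedSampler is reshuffled with a new epoch each time the underlying
# loader is exhausted.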
@@ -54,7 +58,7 @@ class DistributedEvalSampler(Sampler):
Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
process can pass a :class`~torch.utils.data.DistributedSampler` instance as
process can pass a :class`~DistributedEvalSampler` instance as
a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
original dataset that is exclusive to it.
.. note::
@@ -144,3 +148,97 @@ def set_epoch(self, epoch):
epoch (int): Epoch number.
"""
self.epoch = epoch
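
# Usage sketch (illustrative; `eval_dataset`, `eval_batch_size`, `N_GPUS` and
# `RANK` are placeholder names): under DDP, each process passes a
# DistributedEvalSampler to its DataLoader so that it evaluates a disjoint,
# unpadded subset of the eval split:
#
#   sampler = DistributedEvalSampler(eval_dataset, num_replicas=N_GPUS, rank=RANK)
#   loader = DataLoader(eval_dataset, batch_size=eval_batch_size, sampler=sampler)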


# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/
# ConvNets/image_classification/dataloaders.py
def fast_collate(batch, memory_format=torch.contiguous_format):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros(
(len(imgs), 3, h, w),
dtype=torch.uint8).contiguous(memory_format=memory_format)
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if nump_array.ndim < 3:
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array.copy())
return tensor, targets
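
# Usage sketch (illustrative; `image_dataset` and `batch_size` are
# placeholders): `fast_collate` is meant as the `collate_fn` of a DataLoader
# over (PIL.Image, label) pairs. It stacks the images into a single uint8 NCHW
# tensor without normalizing them, deferring normalization to the GPU (see
# PrefetchedWrapper below):
#
#   loader = torch.utils.data.DataLoader(image_dataset,
#                                        batch_size=batch_size,
#                                        collate_fn=fast_collate,
#                                        pin_memory=True)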


# Inspired by
# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/
# ConvNets/image_classification/dataloaders.py
class PrefetchedWrapper:

def __init__(self, dataloader, device, mean, std, start_epoch=0):
self.dataloader = dataloader
self.epoch = start_epoch
self.device = device
self.data_mean = torch.tensor([i / 255 for i in mean],
device=device).view(1, 3, 1, 1)
self.data_std = torch.tensor([i / 255 for i in std],
device=device).view(1, 3, 1, 1)

def __len__(self):
return len(self.dataloader)

def __iter__(self):
if isinstance(self.dataloader.sampler,
(DistributedSampler, DistributedEvalSampler)):
self.dataloader.sampler.set_epoch(self.epoch)
self.epoch += 1
return self.prefetched_loader()

def prefetched_loader(self):
stream = torch.cuda.Stream()
first = True

for next_inputs, next_targets in self.dataloader:
with torch.cuda.stream(stream):
next_inputs = next_inputs.to(
self.device, dtype=torch.float,
non_blocking=True).sub(self.data_mean).div(self.data_std)
next_targets = next_targets.to(self.device, non_blocking=True)

if not first:
yield inputs, targets
else:
first = False

torch.cuda.current_stream().wait_stream(stream)
inputs = next_inputs
targets = next_targets

yield inputs, targets
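
# Usage sketch (illustrative; `train_loader`, `train_mean`, `train_stddev` are
# placeholders): wrapping a DataLoader overlaps host-to-device copies and
# normalization (done on a side CUDA stream) with the compute of the current
# step. `mean` and `std` are per-channel values on a 0-255 scale and are
# divided by 255 inside the wrapper:
#
#   loader = PrefetchedWrapper(train_loader, torch.device('cuda:0'),
#                              train_mean, train_stddev)
#   for inputs, targets in loader:
#     ...  # inputs are normalized float32 tensors already on the GPU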


# Inspired by github.com/PetrochukM/PyTorch-NLP/blob/master/torchnlp/samplers/
# distributed_sampler.py
class TFDistributedSampler:

def __init__(self, iterator, device='cuda:0', rank=None):
self.iterator = iterator
self.device = device
self.rank = rank
if rank is None:
if not torch.distributed.is_initialized():
raise RuntimeError('Requires `torch.distributed` to be initialized.')
self.rank = torch.distributed.get_rank()

def __iter__(self):
return self

def __next__(self):
batch = next(self.iterator)
batch = {
# Assumes that len(value) > self.rank, i.e. there needs to be data for
# each rank/GPU.
key: torch.as_tensor(
value[self.rank], device=self.device, dtype=torch.int64) for key,
value in batch.items()
}
return batch
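
# Usage sketch (illustrative; `sharded_batches` and `RANK` are placeholders):
# given an iterator over batches whose values carry a leading dimension equal
# to the number of DDP processes (e.g. the output of `shard_numpy_ds`), each
# process keeps only its own slice as int64 tensors on its device:
#
#   sampler = TFDistributedSampler(iter(sharded_batches),
#                                  device=f'cuda:{RANK}')
#   batch = next(sampler)  # {'inputs': ..., 'targets': ...} for this rank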
@@ -214,8 +214,6 @@ def create_split(split,
area_range=(0.08, 1.0)):
"""Creates a split from the ImageNet dataset using TensorFlow Datasets."""
del num_batches
if split == 'eval_train':
split = 'train[:50000]'

shuffle_rng, preprocess_rng = jax.random.split(rng, 2)

@@ -57,6 +57,8 @@ def _build_dataset(self,
ds_builder = tfds.builder('imagenet2012:5.*.*', data_dir=data_dir)
ds_builder.download_and_prepare()
train = split == 'train'
if split == 'eval_train':
split = f'train[:{self.num_eval_train_examples}]'
ds = input_pipeline.create_input_iter(
split,
ds_builder,
@@ -22,8 +22,8 @@
from algorithmic_efficiency.workloads.imagenet.workload import \
BaseImagenetWorkload

PYTORCH_DDP = 'LOCAL_RANK' in os.environ
RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
N_GPUS = torch.cuda.device_count()

@@ -54,12 +54,7 @@ def build_input_queue(self,
split: str,
data_dir: str,
global_batch_size: int):
it = self._build_dataset(data_rng, split, data_dir, global_batch_size)
for batch in it:
yield {
'inputs': batch['inputs'].float().to(DEVICE, non_blocking=True),
'targets': batch['targets'].to(DEVICE, non_blocking=True),
}
return self._build_dataset(data_rng, split, data_dir, global_batch_size)

def _build_dataset(self,
data_rng: spec.RandomState,
@@ -69,16 +64,9 @@ def _build_dataset(self,
del data_rng
is_train = split == 'train'

normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[i / 255 for i in self.train_mean],
std=[i / 255 for i in self.train_stddev])
])
eval_transform_config = transforms.Compose([
transforms.Resize(self.resize_size),
transforms.CenterCrop(self.center_crop_size),
normalize
])
transform_config = {
'train':
@@ -88,7 +76,6 @@
scale=self.scale_ratio_range,
ratio=self.aspect_ratio_range),
transforms.RandomHorizontalFlip(),
normalize
]),
'eval_train':
eval_transform_config,
@@ -108,7 +95,7 @@
range(self.num_eval_train_examples))

sampler = None
if PYTORCH_DDP:
if USE_PYTORCH_DDP:
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
@@ -119,13 +106,17 @@
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=not PYTORCH_DDP and is_train,
shuffle=not USE_PYTORCH_DDP and is_train,
sampler=sampler,
num_workers=0,
num_workers=4,
pin_memory=True,
collate_fn=data_utils.fast_collate,
drop_last=is_train)

dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP)
dataloader = data_utils.PrefetchedWrapper(dataloader,
DEVICE,
self.train_mean,
self.train_stddev)
dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)

return dataloader

@@ -137,7 +128,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
}
model.to(DEVICE)
if N_GPUS > 1:
if PYTORCH_DDP:
if USE_PYTORCH_DDP:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[RANK], output_device=RANK)
else:
@@ -245,7 +236,7 @@ def _eval_model_on_split(self,
total_metrics = {
k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if PYTORCH_DDP:
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
return {k: float(v.item() / num_examples) for k, v in total_metrics.items()}
@@ -1,5 +1,6 @@
"""MNIST workload implemented in Jax."""
import functools
import itertools
from typing import Any, Dict, Tuple

from flax import jax_utils
@@ -102,7 +103,11 @@ def build_input_queue(self,
split: str,
data_dir: str,
global_batch_size: int) -> Dict[str, Any]:
return self._build_dataset(data_rng, split, data_dir, global_batch_size)
ds = self._build_dataset(data_rng, split, data_dir, global_batch_size)
if split != 'train':
# Note that this stores the entire eval dataset in memory.
ds = itertools.cycle(ds)
return ds

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
@@ -17,8 +17,8 @@
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload

PYTORCH_DDP = 'LOCAL_RANK' in os.environ
RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0
DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu')
N_GPUS = torch.cuda.device_count()

@@ -78,7 +78,7 @@ def _build_dataset(self,
is_train = split == 'train'

sampler = None
if PYTORCH_DDP:
if USE_PYTORCH_DDP:
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
@@ -89,13 +89,12 @@
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=not PYTORCH_DDP and is_train,
shuffle=not USE_PYTORCH_DDP and is_train,
sampler=sampler,
num_workers=0,
pin_memory=True,
drop_last=is_train)
if is_train:
dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP)
dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)

return dataloader

@@ -118,14 +117,9 @@ def build_input_queue(self,
global_batch_size: int) -> Dict[str, Any]:
it = self._build_dataset(data_rng, split, data_dir, global_batch_size)
for batch in it:
if isinstance(batch, dict):
inputs = batch['inputs']
targets = batch['targets']
else:
inputs, targets = batch
yield {
'inputs': inputs.to(DEVICE, non_blocking=True),
'targets': targets.to(DEVICE, non_blocking=True),
'inputs': batch['inputs'].to(DEVICE, non_blocking=True),
'targets': batch['targets'].to(DEVICE, non_blocking=True),
}

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
@@ -136,7 +130,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
}
model.to(DEVICE)
if N_GPUS > 1:
if PYTORCH_DDP:
if USE_PYTORCH_DDP:
model = DDP(model, device_ids=[RANK], output_device=RANK)
else:
model = torch.nn.DataParallel(model)