From d2ebae5fd2980bae3db4993cf319c916e41ce7fd Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 7 Jun 2022 18:18:14 +0200 Subject: [PATCH 01/14] Fix eval metric name --- algorithmic_efficiency/workloads/imagenet/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet/workload.py b/algorithmic_efficiency/workloads/imagenet/workload.py index dbb03d861..849b0a32a 100644 --- a/algorithmic_efficiency/workloads/imagenet/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/workload.py @@ -9,7 +9,7 @@ def __init__(self): self._param_shapes = None def has_reached_goal(self, eval_result: float) -> bool: - return eval_result['accuracy'] > self.target_value + return eval_result['validation/accuracy'] > self.target_value @property def target_value(self): From 9fef79eb38d8817088b3a29e8125dd0ff4412af4 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 7 Jun 2022 18:25:42 +0200 Subject: [PATCH 02/14] Minor clean up --- reference_submissions/imagenet/imagenet_jax/submission.py | 6 ------ .../imagenet/imagenet_pytorch/submission.py | 2 ++ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/reference_submissions/imagenet/imagenet_jax/submission.py b/reference_submissions/imagenet/imagenet_jax/submission.py index b6b88ea64..90e15b8d9 100644 --- a/reference_submissions/imagenet/imagenet_jax/submission.py +++ b/reference_submissions/imagenet/imagenet_jax/submission.py @@ -18,12 +18,6 @@ def get_batch_size(workload_name): return 128 -def cosine_decay(lr, step, total_steps): - ratio = jnp.maximum(0., step / total_steps) - mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio)) - return mult * lr - - def create_learning_rate_fn(hparams: spec.Hyperparameters, steps_per_epoch: int): """Create learning rate schedule.""" diff --git a/reference_submissions/imagenet/imagenet_pytorch/submission.py b/reference_submissions/imagenet/imagenet_pytorch/submission.py index df7a72119..85e523027 100644 --- a/reference_submissions/imagenet/imagenet_pytorch/submission.py +++ b/reference_submissions/imagenet/imagenet_pytorch/submission.py @@ -109,8 +109,10 @@ def data_selection(workload: spec.Workload, Each element of the queue is a batch of training examples and labels. """ + del workload del optimizer_state del current_param_container + del hyperparameters del global_step del rng return next(input_queue) From 259e0b498437609fd4796edd8c537f4b667f7093 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 7 Jun 2022 20:09:05 +0200 Subject: [PATCH 03/14] Fix submission_runner test --- tests/submission_runner_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index 924187015..799264a63 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -6,6 +6,7 @@ """ import copy import os +import sys from absl import flags from absl import logging @@ -15,6 +16,9 @@ import submission_runner FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +# (see https://github.com/google/model_search/pull/8). 
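+# Calling FLAGS(sys.argv) parses the command-line flags at import time, so the
+# test can read flag values before absl.app.run() has been invoked.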
+FLAGS(sys.argv) class SubmissionRunnerTest(parameterized.TestCase): From 9ed87e9f83d1125820c44855fc97cef067cc50c0 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 8 Jun 2022 18:36:45 +0200 Subject: [PATCH 04/14] Fix MNIST workloads --- .../workloads/mnist/mnist_jax/workload.py | 33 +++++++------- .../workloads/mnist/mnist_pytorch/workload.py | 43 +++++++++++-------- .../workloads/mnist/workload.py | 15 ++++--- .../mnist/mnist_jax/submission.py | 1 + 4 files changed, 50 insertions(+), 42 deletions(-) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index 70436bfe8..558ac159e 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -1,6 +1,6 @@ """MNIST workload implemented in Jax.""" import functools -from typing import Dict, Tuple +from typing import Any, Dict, Tuple from flax import jax_utils from flax import linen as nn @@ -10,7 +10,7 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import param_utils +from algorithmic_efficiency import param_utils, data_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload @@ -58,25 +58,27 @@ def _build_dataset(self, split: str, data_dir: str, batch_size): + # TODO: choose a random split and match with PyTorch. if split == 'eval_train': - tfds_split = 'train[:50000]' + tfds_split = f'train[:{self.num_eval_train_examples}]' elif split == 'validation': - tfds_split = 'train[50000:]' + tfds_split = f'train[{self.num_train_examples}:]' else: - tfds_split = split + tfds_split = f'train[:{self.num_train_examples}]' ds = tfds.load( 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir) - ds = ds.cache() ds = ds.map(lambda x: { 'inputs': self._normalize(x['image']), 'targets': x['label'], }) - if split == 'train': - ds = ds.shuffle(1024, seed=data_rng[0]) + ds = ds.cache() + is_train = split == 'train' + if is_train: + ds = ds.shuffle(16 * batch_size, seed=data_rng[0]) ds = ds.repeat() - ds = ds.batch(batch_size) - ds = ds.batch(jax.local_device_count()) - return tfds.as_numpy(ds) + ds = ds.batch(batch_size, drop_remainder=is_train) + ds = map(data_utils.shard_numpy_ds, ds) + return iter(ds) @property def model_params_types(self): @@ -98,10 +100,8 @@ def build_input_queue(self, data_rng, split: str, data_dir: str, - global_batch_size: int): - ds = self._build_dataset(data_rng, split, data_dir, global_batch_size) - for images, labels in iter(ds): - yield images, labels, None + global_batch_size: int) -> Dict[str, Any]: + return self._build_dataset(data_rng, split, data_dir, global_batch_size) def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) @@ -168,7 +168,6 @@ def _eval_model( update_batch_norm=False) accuracy = jnp.sum(jnp.argmax(logits, axis=-1) == batch['targets']) loss = jnp.sum(self.loss_fn(batch['targets'], logits)) - num_data = len(logits) - metrics = {'accuracy': accuracy, 'loss': loss, 'num_data': num_data} + metrics = {'accuracy': accuracy, 'loss': loss} metrics = lax.psum(metrics, axis_name='batch') return metrics diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index 3fa8e425a..b76271e30 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ 
b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -1,17 +1,16 @@ """MNIST workload implemented in PyTorch.""" from collections import OrderedDict import contextlib -import itertools from typing import Any, Dict, Tuple import torch from torch import nn import torch.nn.functional as F -import torch.utils.data as data_utils +import torch.utils.data as pytorch_data_utils from torchvision import transforms from torchvision.datasets import MNIST -from algorithmic_efficiency import param_utils +from algorithmic_efficiency import param_utils, data_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload @@ -38,13 +37,6 @@ def forward(self, x: spec.Tensor): return self.net(x) -class DictMNIST(MNIST): - - def __getitem__(self, index: int) -> Dict[str, Any]: - image, label = super().__getitem__(index) - return {'inputs': image, 'targets': label} - - class MnistWorkload(BaseMnistWorkload): def _build_dataset(self, @@ -58,11 +50,11 @@ def _build_dataset(self, transforms.ToTensor(), transforms.Normalize((self.train_mean,), (self.train_stddev,)) ]) - dataset = DictMNIST( + dataset = MNIST( data_dir, train=dataloader_split, download=True, transform=transform) if split != 'test': if split in ['train', 'validation']: - train_dataset, validation_dataset = data_utils.random_split( + train_dataset, validation_dataset = pytorch_data_utils.random_split( dataset, [self.num_train_examples, self.num_validation_examples], generator=torch.Generator().manual_seed(int(data_rng[0]))) @@ -71,7 +63,7 @@ def _build_dataset(self, elif split == 'validation': dataset = validation_dataset if split == 'eval_train': - dataset, _ = data_utils.random_split( + dataset, _ = pytorch_data_utils.random_split( dataset, [self.num_eval_train_examples, 60000 - self.num_eval_train_examples], @@ -79,9 +71,13 @@ def _build_dataset(self, # TODO: set seeds properly is_train = split == 'train' dataloader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, shuffle=is_train, pin_memory=True) + dataset, + batch_size=batch_size, + shuffle=is_train, + pin_memory=True, + drop_last=is_train) if is_train: - dataloader = itertools.cycle(dataloader) + dataloader = data_utils.cycle(dataloader) return dataloader @@ -101,8 +97,18 @@ def build_input_queue(self, data_rng, split: str, data_dir: str, - global_batch_size: int): - return self._build_dataset(data_rng, split, data_dir, global_batch_size) + global_batch_size: int) -> Dict[str, Any]: + it = self._build_dataset(data_rng, split, data_dir, global_batch_size) + for batch in it: + if isinstance(batch, dict): + inputs = batch['inputs'] + targets = batch['targets'] + else: + inputs, targets = batch + yield { + 'inputs': inputs.to(DEVICE, non_blocking=True), + 'targets': targets.to(DEVICE, non_blocking=True), + } def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) @@ -175,5 +181,4 @@ def _eval_model( # Number of correct predictions. 
accuracy = (predicted == batch['targets']).sum().item() loss = self.loss_fn(batch['targets'], logits).sum().item() - num_data = len(logits) - return {'accuracy': accuracy, 'loss': loss, 'num_data': num_data} + return {'accuracy': accuracy, 'loss': loss} diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index fb0da28a2..ba9655111 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -2,12 +2,16 @@ import itertools import math from typing import Dict, Tuple +from absl import flags import jax +from flax import jax_utils from algorithmic_efficiency import spec import algorithmic_efficiency.random_utils as prng +FLAGS =flags.FLAGS + class BaseMnistWorkload(spec.Workload): @@ -101,11 +105,9 @@ def _eval_model_on_split(self, 'accuracy': 0., 'loss': 0., } - num_data = 0 num_batches = int(math.ceil(num_examples / global_batch_size)) - for bi, batch in enumerate(self._eval_iters[split]): - if bi > num_batches: - break + for _ in range(num_batches): + batch = next(self._eval_iters[split]) per_device_model_rngs = prng.split(model_rng, jax.local_device_count()) batch_metrics = self._eval_model(params, batch, @@ -114,5 +116,6 @@ def _eval_model_on_split(self, total_metrics = { k: v + batch_metrics[k] for k, v in total_metrics.items() } - num_data += batch_metrics['num_data'] - return {k: float(v / num_data) for k, v in total_metrics.items()} + if FLAGS.framework == 'jax': + total_metrics = jax_utils.unreplicate(total_metrics) + return {k: float(v / num_examples) for k, v in total_metrics.items()} diff --git a/reference_submissions/mnist/mnist_jax/submission.py b/reference_submissions/mnist/mnist_jax/submission.py index 8ca6ca50d..275eed50e 100644 --- a/reference_submissions/mnist/mnist_jax/submission.py +++ b/reference_submissions/mnist/mnist_jax/submission.py @@ -52,6 +52,7 @@ def pmapped_update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], optimizer_state: spec.OptimizerState, rng: spec.RandomState) -> spec.UpdateReturn: + del hyperparameters def loss_fn(params): logits_batch, new_model_state = workload.model_fn( From 07589d3ae3f5a8ad58e22f7013dc92789edd693d Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 8 Jun 2022 18:38:58 +0200 Subject: [PATCH 05/14] Add data_utils --- algorithmic_efficiency/data_utils.py | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 algorithmic_efficiency/data_utils.py diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py new file mode 100644 index 000000000..90b07607b --- /dev/null +++ b/algorithmic_efficiency/data_utils.py @@ -0,0 +1,36 @@ +import jax + + +def shard_numpy_ds(xs): + """Prepare tf data for JAX + + Convert an input batch from tf Tensors to numpy arrays and reshape it to be + sharded across devices. + """ + local_device_count = jax.local_device_count() + + def _prepare(x): + # Use _numpy() for zero-copy conversion between TF and NumPy. 
+ x = x._numpy() # pylint: disable=protected-access + + # reshape (host_batch_size, height, width, 3) to + # (local_devices, device_batch_size, height, width, 3) + return x.reshape((local_device_count, -1) + x.shape[1:]) + + return jax.tree_map(_prepare, xs) + + +# github.com/pytorch/pytorch/issues/23900#issuecomment-518858050 +def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False): + iterator = iter(iterable) + epoch = 0 + while True: + try: + batch = next(iterator) + assert len(keys) == len(batch) + yield {key: value for key, value in zip(keys, batch)} + except StopIteration: + if custom_sampler: + epoch += 1 + iterable.sampler.set_epoch(epoch) + iterator = iter(iterable) From 22f10e7407ba73828c604ea4b971ece35a5a9ccf Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 17:20:01 +0200 Subject: [PATCH 06/14] Implement DDP --- submission_runner.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 0a0f0fb57..6c154eea7 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -22,6 +22,8 @@ from absl import flags from absl import logging import tensorflow as tf +import torch +import torch.distributed as dist from algorithmic_efficiency import halton from algorithmic_efficiency import random_utils as prng @@ -32,7 +34,7 @@ tf.config.experimental.set_visible_devices([], 'GPU') # TODO(znado): make a nicer registry of workloads that lookup in. -BASE_WORKLOADS_DIR = "algorithmic_efficiency/workloads/" +BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' # Workload_path will be appended by '_pytorch' or '_jax' automatically. WORKLOADS = { @@ -79,9 +81,10 @@ 'tuning_search_space', 'reference_submissions/mnist/tuning_search_space.json', 'The path to the JSON file describing the external tuning search space.') -flags.DEFINE_integer('num_tuning_trials', - 20, - 'The number of external hyperparameter trials to run.') +flags.DEFINE_integer( + 'num_tuning_trials', + 20, + 'The number of external hyperparameter trials to run.') flags.DEFINE_string('data_dir', '~/tensorflow_datasets/', 'Dataset location') flags.DEFINE_enum( 'framework', @@ -294,6 +297,24 @@ def score_submission_on_workload(workload: spec.Workload, def main(_): + # Check if distributed data parallel is used. + pytorch_ddp = 'LOCAL_RANK' in os.environ + if FLAGS.framework == 'pytorch': + # From the docs: "(...) causes cuDNN to benchmark multiple convolution + # algorithms and select the fastest." 
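+    # Benchmarking can slow down the first training steps while cuDNN profiles
+    # candidate kernels, and it makes algorithm selection non-deterministic.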
+ torch.backends.cudnn.benchmark = True + + if pytorch_ddp: + rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(rank) + # only log once (for local rank == 0) + if rank != 0: + def logging_pass(*args): + pass + logging.info = logging_pass + # initialize the process group + dist.init_process_group('nccl') + workload_metadata = WORKLOADS[FLAGS.workload] # extend path according to framework workload_metadata['workload_path'] = os.path.join( @@ -313,6 +334,10 @@ def main(_): FLAGS.num_tuning_trials) logging.info('Final %s score: %f', FLAGS.workload, score) + if pytorch_ddp: + # cleanup + dist.destroy_process_group() + if __name__ == '__main__': app.run(main) From 1e1da7dfe090ac513633a93f77c6fca53172c877 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 17:24:16 +0200 Subject: [PATCH 07/14] Add DistributedEvalSampler --- algorithmic_efficiency/data_utils.py | 110 +++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 90b07607b..938e98031 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -1,4 +1,7 @@ import jax +import torch +import torch.distributed as dist +from torch.utils.data import Sampler def shard_numpy_ds(xs): @@ -34,3 +37,110 @@ def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False): epoch += 1 iterable.sampler.set_epoch(epoch) iterator = iter(iterable) + + +# github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py +class DistributedEvalSampler(Sampler): + r""" + DistributedEvalSampler is different from DistributedSampler. + It does NOT add extra samples to make it evenly divisible. + DistributedEvalSampler should NOT be used for training. The distributed + processes could hang forever. + See this issue for details: https://github.com/pytorch/pytorch/issues/22584 + shuffle is disabled by default + DistributedEvalSampler is for evaluation purpose where synchronization does + not happen every epoch. + Synchronization should be done outside the dataloader loop. + Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class`~torch.utils.data.DistributedSampler` instance as + a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`rank` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within + :attr:`num_replicas`. By default, :attr:`rank` is retrieved from the + current distributed group. + shuffle (bool, optional): If ``True``, sampler will shuffle the + indices. Default: ``False`` + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + .. warning:: + In distributed mode, calling the :meth`set_epoch(epoch) ` + method at the beginning of each epoch **before** creating the + :class:`DataLoader` iterator is necessary to make shuffling work + properly across multiple epochs. Otherwise, the same ordering will be + always used. 
+ Example:: + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... train(loader) + """ + + def __init__(self, + dataset, + num_replicas=None, + rank=None, + shuffle=False, + seed=0): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError('Requires distributed package to be available.') + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError('Requires distributed package to be available.') + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + # true value without extra samples + self.total_size = len(self.dataset) + indices = list(range(self.total_size)) + indices = indices[self.rank:self.total_size:self.num_replicas] + # true value without extra samples + self.num_samples = len(indices) + + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + r""" + Sets the epoch for this sampler. When :attr:`shuffle=True`, this + ensures all replicas use a different random ordering for each epoch. + Otherwise, the next iteration of this sampler will yield the same + ordering. + Arguments: + epoch (int): _epoch number. 
+ """ + self.epoch = epoch From 0704595dbb1660c07dd36ecbe99355680d5d039d Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 17:43:12 +0200 Subject: [PATCH 08/14] Add DDP to ImageNet --- .../imagenet/imagenet_pytorch/workload.py | 85 ++++++++++++------- 1 file changed, 53 insertions(+), 32 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py index c35119a28..2a1780c5a 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py @@ -8,10 +8,12 @@ import torch from torch import nn import torch.nn.functional as F +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP from torchvision import transforms from torchvision.datasets.folder import ImageFolder -from algorithmic_efficiency import param_utils +from algorithmic_efficiency import param_utils, data_utils from algorithmic_efficiency import spec import algorithmic_efficiency.random_utils as prng from algorithmic_efficiency.workloads.imagenet.imagenet_pytorch.models import \ @@ -19,18 +21,10 @@ from algorithmic_efficiency.workloads.imagenet.workload import \ BaseImagenetWorkload -DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - - -# from https://github.com/pytorch/pytorch/issues/23900#issuecomment-518858050 -def cycle(iterable): - iterator = iter(iterable) - while True: - try: - images, labels = next(iterator) - yield {'inputs': images, 'targets': labels} - except StopIteration: - iterator = iter(iterable) +PYTORCH_DDP = 'LOCAL_RANK' in os.environ +RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0 +DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu') +N_GPUS = torch.cuda.device_count() class ImagenetWorkload(BaseImagenetWorkload): @@ -59,11 +53,11 @@ def build_input_queue(self, split: str, data_dir: str, global_batch_size: int): - it = iter(self._build_dataset(data_rng, split, data_dir, global_batch_size)) + it = self._build_dataset(data_rng, split, data_dir, global_batch_size) for batch in it: yield { - 'inputs': batch['inputs'].float().to(DEVICE), - 'targets': batch['targets'].to(DEVICE), + 'inputs': batch['inputs'].float().to(DEVICE, non_blocking=True), + 'targets': batch['targets'].to(DEVICE, non_blocking=True), } def _build_dataset(self, @@ -71,7 +65,8 @@ def _build_dataset(self, split: str, data_dir: str, batch_size: int): - is_train = (split == "train") + del data_rng + is_train = split == 'train' normalize = transforms.Compose([ transforms.ToTensor(), @@ -106,15 +101,36 @@ def _build_dataset(self, os.path.join(data_dir, folder[split]), transform=transform_config[split]) + if split == 'eval_train': + # We always use the same subset of the training data for evaluation. 
+ dataset = torch.utils.data.Subset( + dataset, range(self.num_eval_train_examples)) + + sampler = None + if PYTORCH_DDP: + if is_train: + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=N_GPUS, + rank=RANK, + shuffle=True) + else: + sampler = data_utils.DistributedEvalSampler( + dataset, + num_replicas=N_GPUS, + rank=RANK, + shuffle=False) + batch_size //= N_GPUS dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, - shuffle=is_train, - num_workers=5, + shuffle=not PYTORCH_DDP and is_train, + sampler=sampler, + num_workers=0, pin_memory=True, drop_last=is_train) - dataloader = cycle(dataloader) + dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP) return dataloader @@ -124,14 +140,19 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._param_shapes = { k: spec.ShapeTuple(v.shape) for k, v in model.named_parameters() } - if torch.cuda.device_count() > 1: - model = torch.nn.DataParallel(model) model.to(DEVICE) + if N_GPUS > 1: + if PYTORCH_DDP: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DDP(model, device_ids=[RANK], output_device=RANK) + else: + model = torch.nn.DataParallel(model) return model, None def _update_batch_norm(self, model, update_batch_norm): for m in model.modules(): - if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, + nn.BatchNorm3d, nn.SyncBatchNorm)): if not update_batch_norm: m.eval() m.requires_grad_(update_batch_norm) @@ -191,10 +212,9 @@ def _eval_metric(self, logits, labels): """Return the mean accuracy and loss as a dict.""" predicted = torch.argmax(logits, 1) # not accuracy, but nr. of correct predictions - accuracy = (predicted == labels).sum().item() - loss = self.loss_fn(labels, logits).sum().item() - num_data = len(logits) - return {'accuracy': accuracy, 'loss': loss, 'num_data': num_data} + accuracy = (predicted == labels).sum() + loss = self.loss_fn(labels, logits).sum() + return {'accuracy': accuracy, 'loss': loss} def _eval_model_on_split(self, split: str, @@ -212,10 +232,9 @@ def _eval_model_on_split(self, data_rng, split, data_dir, global_batch_size=global_batch_size) total_metrics = { - 'accuracy': 0., - 'loss': 0., + 'accuracy': torch.tensor(0., device=DEVICE), + 'loss': torch.tensor(0., device=DEVICE), } - num_data = 0 num_batches = int(math.ceil(num_examples / global_batch_size)) for _ in range(num_batches): batch = next(self._eval_iters[split]) @@ -230,5 +249,7 @@ def _eval_model_on_split(self, total_metrics = { k: v + batch_metrics[k] for k, v in total_metrics.items() } - num_data += batch_metrics['num_data'] - return {k: float(v / num_data) for k, v in total_metrics.items()} + if PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} From 801591c852c120b511af8ed35b03a571037ba333 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 18:29:27 +0200 Subject: [PATCH 09/14] Fix Jax ImageNet workload --- algorithmic_efficiency/spec.py | 5 +- .../imagenet/imagenet_jax/input_pipeline.py | 53 ++++++------------- .../imagenet/imagenet_jax/workload.py | 27 +++++++--- 3 files changed, 41 insertions(+), 44 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 01d43b26d..05f4d2b48 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -106,7 +106,10 @@ def build_input_queue(self, 
data_rng: RandomState, split: str, data_dir: str, - global_batch_size: int) -> Dict[str, Any]: + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None) -> Dict[str, Any]: """Build the input queue for the workload data. This is the only function that is NOT allowed to be called by submitters. diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py index 4096f3d08..e8b91110c 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py @@ -9,6 +9,8 @@ import tensorflow as tf import tensorflow_datasets as tfds +import algorithmic_efficiency.data_utils as data_utils + IMAGE_SIZE = 224 RESIZE_SIZE = 256 MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] @@ -156,16 +158,18 @@ def preprocess_for_train(image_bytes, Returns: A preprocessed image `Tensor`. """ - crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2) + # Note (runame): Cannot be done in graph mode, i.e. during ds.map(). + # Alternative? + # crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2) image = _decode_and_random_crop(image_bytes, - crop_rng, + rng, image_size, aspect_ratio_range, area_range, resize_size) image = tf.reshape(image, [image_size, image_size, 3]) - image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng) + image = tf.image.stateless_random_flip_left_right(image, seed=rng) image = normalize_image(image, mean_rgb, stddev_rgb) image = tf.image.convert_image_dtype(image, dtype=dtype) return image @@ -209,22 +213,19 @@ def create_split(split, aspect_ratio_range=(0.75, 4.0 / 3.0), area_range=(0.08, 1.0)): """Creates a split from the ImageNet dataset using TensorFlow Datasets.""" + del num_batches if split == 'eval_train': - split = 'train' + split = 'train[:50000]' shuffle_rng, preprocess_rng = jax.random.split(rng, 2) - def decode_example(example): + def decode_example(example_index, example): dtype = tf.float32 if train: - # We call ds.enumerate() to get a globally unique per-example, per-step - # index that we can fold into the RNG seed. - (example_index, example) = example per_step_preprocess_rng = tf.random.experimental.stateless_fold_in( tf.cast(preprocess_rng, tf.int64), example_index) image = preprocess_for_train(example['image'], per_step_preprocess_rng, - example_index, mean_rgb, stddev_rgb, aspect_ratio_range, @@ -246,7 +247,7 @@ def decode_example(example): 'image': tfds.decode.SkipDecoding(), }) options = tf.data.Options() - options.experimental_threading.private_threadpool_size = 48 + options.threading.private_threadpool_size = 48 ds = ds.with_options(options) if cache: @@ -256,11 +257,11 @@ def decode_example(example): ds = ds.repeat() ds = ds.shuffle(16 * global_batch_size, seed=shuffle_rng[0]) + # We call ds.enumerate() to get a globally unique per-example, per-step + # index that we can fold into the RNG seed. 
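+  # enumerate() yields (example_index, example) pairs, which decode_example
+  # consumes to derive a per-example preprocessing seed.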
+ ds = ds.enumerate() ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) - ds = ds.batch(global_batch_size, drop_remainder=True) - - if num_batches is not None: - ds = ds.take(num_batches) + ds = ds.batch(global_batch_size, drop_remainder=train) if repeat_final_dataset: ds = ds.repeat() @@ -270,25 +271,6 @@ def decode_example(example): return ds -def shard_numpy_ds(xs): - """Prepare tf data for JAX - - Convert an input batch from tf Tensors to numpy arrays and reshape it to be - sharded across devices. - """ - local_device_count = jax.local_device_count() - - def _prepare(x): - # Use _numpy() for zero-copy conversion between TF and NumPy. - x = x._numpy() # pylint: disable=protected-access - - # reshape (host_batch_size, height, width, 3) to - # (local_devices, device_batch_size, height, width, 3) - return x.reshape((local_device_count, -1) + x.shape[1:]) - - return jax.tree_map(_prepare, xs) - - def create_input_iter(split, dataset_builder, rng, @@ -309,7 +291,6 @@ def create_input_iter(split, rng, global_batch_size, train=train, - dtype=tf.float32, image_size=image_size, resize_size=resize_size, mean_rgb=mean_rgb, @@ -319,9 +300,9 @@ def create_input_iter(split, num_batches=num_batches, aspect_ratio_range=aspect_ratio_range, area_range=area_range) - it = map(shard_numpy_ds, ds) + it = map(data_utils.shard_numpy_ds, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. it = jax_utils.prefetch_to_device(it, 2) - return it + return iter(it) diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py index ebdf2942b..16c3f7406 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py @@ -32,9 +32,18 @@ def build_input_queue(self, data_rng: spec.RandomState, split: str, data_dir: str, - global_batch_size: int): - return iter( - self._build_dataset(data_rng, split, data_dir, global_batch_size)) + global_batch_size: int, + cache: Optional[bool] = None, + repeat_final_dataset: Optional[bool] = None, + num_batches: Optional[int] = None): + return self._build_dataset( + data_rng, + split, + data_dir, + global_batch_size, + cache, + repeat_final_dataset, + num_batches) def _build_dataset(self, data_rng: spec.RandomState, @@ -144,6 +153,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del mode + del rng variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( @@ -171,13 +182,14 @@ def loss_fn(self, label_batch: spec.Tensor, return xentropy def _compute_metrics(self, logits, labels): - loss = jnp.mean(self.loss_fn(labels, logits)) - accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + loss = jnp.sum(self.loss_fn(labels, logits)) + # not accuracy, but nr. 
of correct predictions + accuracy = jnp.sum(jnp.argmax(logits, -1) == labels) metrics = { 'loss': loss, 'accuracy': accuracy, } - metrics = lax.pmean(metrics, axis_name='batch') + metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -213,5 +225,6 @@ def _eval_model_on_split(self, eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics) + eval_metrics = jax.tree_map( + lambda x: float(x[0] / num_examples), eval_metrics) return eval_metrics From a235b8872dd8575e196ec8d84cdc0b8e6daf4089 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 20:01:10 +0200 Subject: [PATCH 10/14] Fix yapf/isort/pylint --- algorithmic_efficiency/data_utils.py | 2 +- .../imagenet/imagenet_jax/input_pipeline.py | 2 +- .../imagenet/imagenet_jax/workload.py | 19 +++++++-------- .../imagenet/imagenet_pytorch/workload.py | 24 ++++++++----------- .../workloads/mnist/mnist_jax/workload.py | 3 ++- .../workloads/mnist/workload.py | 6 ++--- submission_runner.py | 9 +++---- 7 files changed, 31 insertions(+), 34 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 938e98031..56038d96d 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -31,7 +31,7 @@ def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False): try: batch = next(iterator) assert len(keys) == len(batch) - yield {key: value for key, value in zip(keys, batch)} + yield dict(zip(keys, batch)) except StopIteration: if custom_sampler: epoch += 1 diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py index e8b91110c..eee992175 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py @@ -9,7 +9,7 @@ import tensorflow as tf import tensorflow_datasets as tfds -import algorithmic_efficiency.data_utils as data_utils +from algorithmic_efficiency import data_utils IMAGE_SIZE = 224 RESIZE_SIZE = 256 diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py index 16c3f7406..d3b4c66da 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py @@ -36,14 +36,13 @@ def build_input_queue(self, cache: Optional[bool] = None, repeat_final_dataset: Optional[bool] = None, num_batches: Optional[int] = None): - return self._build_dataset( - data_rng, - split, - data_dir, - global_batch_size, - cache, - repeat_final_dataset, - num_batches) + return self._build_dataset(data_rng, + split, + data_dir, + global_batch_size, + cache, + repeat_final_dataset, + num_batches) def _build_dataset(self, data_rng: spec.RandomState, @@ -225,6 +224,6 @@ def _eval_model_on_split(self, eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree_map( - lambda x: float(x[0] / num_examples), eval_metrics) + eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples), + eval_metrics) return eval_metrics diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py index 2a1780c5a..05c0f7d18 100644 --- 
a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py @@ -7,13 +7,14 @@ import torch from torch import nn -import torch.nn.functional as F import torch.distributed as dist +import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torchvision import transforms from torchvision.datasets.folder import ImageFolder -from algorithmic_efficiency import param_utils, data_utils +from algorithmic_efficiency import data_utils +from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec import algorithmic_efficiency.random_utils as prng from algorithmic_efficiency.workloads.imagenet.imagenet_pytorch.models import \ @@ -103,23 +104,17 @@ def _build_dataset(self, if split == 'eval_train': # We always use the same subset of the training data for evaluation. - dataset = torch.utils.data.Subset( - dataset, range(self.num_eval_train_examples)) + dataset = torch.utils.data.Subset(dataset, + range(self.num_eval_train_examples)) sampler = None if PYTORCH_DDP: if is_train: sampler = torch.utils.data.distributed.DistributedSampler( - dataset, - num_replicas=N_GPUS, - rank=RANK, - shuffle=True) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) else: sampler = data_utils.DistributedEvalSampler( - dataset, - num_replicas=N_GPUS, - rank=RANK, - shuffle=False) + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False) batch_size //= N_GPUS dataloader = torch.utils.data.DataLoader( dataset, @@ -151,8 +146,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: def _update_batch_norm(self, model, update_batch_norm): for m in model.modules(): - if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, - nn.BatchNorm3d, nn.SyncBatchNorm)): + if isinstance( + m, + (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)): if not update_batch_norm: m.eval() m.requires_grad_(update_batch_norm) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index 558ac159e..a3c433e71 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -10,7 +10,8 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import param_utils, data_utils +from algorithmic_efficiency import data_utils +from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index ba9655111..015d44be7 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -2,15 +2,15 @@ import itertools import math from typing import Dict, Tuple -from absl import flags -import jax +from absl import flags from flax import jax_utils +import jax from algorithmic_efficiency import spec import algorithmic_efficiency.random_utils as prng -FLAGS =flags.FLAGS +FLAGS = flags.FLAGS class BaseMnistWorkload(spec.Workload): diff --git a/submission_runner.py b/submission_runner.py index 6c154eea7..097990caa 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -81,10 +81,9 @@ 'tuning_search_space', 'reference_submissions/mnist/tuning_search_space.json', 'The path to the JSON file describing the external tuning 
search space.') -flags.DEFINE_integer( - 'num_tuning_trials', - 20, - 'The number of external hyperparameter trials to run.') +flags.DEFINE_integer('num_tuning_trials', + 20, + 'The number of external hyperparameter trials to run.') flags.DEFINE_string('data_dir', '~/tensorflow_datasets/', 'Dataset location') flags.DEFINE_enum( 'framework', @@ -309,8 +308,10 @@ def main(_): torch.cuda.set_device(rank) # only log once (for local rank == 0) if rank != 0: + def logging_pass(*args): pass + logging.info = logging_pass # initialize the process group dist.init_process_group('nccl') From 4f3dda12e9c299eb606e2c692f1f7997516c6d2b Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 20:13:06 +0200 Subject: [PATCH 11/14] Missing isort fix --- .../workloads/mnist/mnist_pytorch/workload.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index b76271e30..4a41c2c5b 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -10,7 +10,8 @@ from torchvision import transforms from torchvision.datasets import MNIST -from algorithmic_efficiency import param_utils, data_utils +from algorithmic_efficiency import data_utils +from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload From 889690bf8d8547476816899d8c61683d99ce3c97 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 20:40:50 +0200 Subject: [PATCH 12/14] Adjust setup for pylint>=2.14.0 --- setup.cfg | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/setup.cfg b/setup.cfg index b58784ff9..de0ffc48c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -142,7 +142,7 @@ profile=google # pylint configuration [pylint.MASTER] persistent=no # Pickle collected data for later comparisons. -cache-size=500 # Set the cache size for astng objects. +#cache-size=500 # Set the cache size for astng objects. # Ignore Py3 files ignore=get_references_web.py,get_references_web_single_group.py [pylint.REPORTS] @@ -151,11 +151,11 @@ ignore=get_references_web.py,get_references_web_single_group.py # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". -files-output=no +#files-output=no # Tells whether to display a full report or only the messages. reports=no # Disable the report(s) with the given id(s). -disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 +#disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 # Error message template (continued on second line) msg-template={msg_id}:{line:3} {obj}: {msg} [{symbol}] [pylint.'MESSAGES CONTROL'] @@ -165,7 +165,7 @@ enable=indexing-exception,old-raise-syntax [pylint.BASIC] # Required attributes for module, separated by a comma -required-attributes= +#required-attributes= # Regular expression which should only match the name # of functions or classes which do not require a docstring. 
no-docstring-rgx=(__.*__|main) @@ -212,7 +212,7 @@ good-names=main,_ # Bad variable names which should always be refused, separated by a comma bad-names= # List of builtins function names that should not be used, separated by a comma -bad-functions=input,apply,reduce +#bad-functions=input,apply,reduce # List of decorators that define properties, such as abc.abstractproperty. property-classes=abc.abstractproperty [pylint.TYPECHECK] @@ -267,21 +267,19 @@ indent-string=' ' # Do not warn about multiple statements on a single line for constructs like # if test: stmt single-line-if-stmt=y -# Make sure : in dicts and trailing commas are checked for whitespace. -no-space-check= [pylint.LOGGING] # Add logging modules. logging-modules=logging,absl.logging [pylint.MISCELLANEOUS] # Maximum line length for lambdas -short-func-length=1 +#short-func-length=1 # List of module members that should be marked as deprecated. # All of the string functions are listed in 4.1.4 Deprecated string functions # in the Python 2.4 docs. -deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint +#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint # List of exceptions that do not need to be mentioned in the Raises section of # a docstring. -ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError +#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError # Number of spaces of indent required when the last token on the preceding line # is an open (, [, or {. 
indent-after-paren=4 From 997dbd305230a2a4ddd4377a1a2fb060e23de967 Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 9 Jun 2022 21:02:07 +0200 Subject: [PATCH 13/14] Use generator instead of list comprehension --- reference_submissions/cifar/cifar_jax/submission.py | 3 +-- reference_submissions/imagenet/imagenet_jax/submission.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/reference_submissions/cifar/cifar_jax/submission.py b/reference_submissions/cifar/cifar_jax/submission.py index 5cf61e12f..f45566092 100644 --- a/reference_submissions/cifar/cifar_jax/submission.py +++ b/reference_submissions/cifar/cifar_jax/submission.py @@ -93,8 +93,7 @@ def _loss_fn(params): update_batch_norm=True) loss = jnp.mean(workload.loss_fn(batch['targets'], logits)) weight_penalty_params = jax.tree_leaves(params) - weight_l2 = sum( - [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1]) + weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 loss = loss + weight_penalty return loss, new_model_state diff --git a/reference_submissions/imagenet/imagenet_jax/submission.py b/reference_submissions/imagenet/imagenet_jax/submission.py index 90e15b8d9..39634d1bf 100644 --- a/reference_submissions/imagenet/imagenet_jax/submission.py +++ b/reference_submissions/imagenet/imagenet_jax/submission.py @@ -88,8 +88,7 @@ def _loss_fn(params): update_batch_norm=True) loss = jnp.mean(workload.loss_fn(batch['targets'], logits)) weight_penalty_params = jax.tree_leaves(variables['params']) - weight_l2 = sum( - [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1]) + weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 loss = loss + weight_penalty return loss, (new_model_state, logits) From c4a07ac76aa311af883886579344581cc4f9394c Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 10 Jun 2022 15:09:59 +0200 Subject: [PATCH 14/14] Add DDP to MNIST --- .../workloads/mnist/mnist_pytorch/workload.py | 34 +++++++++++++++---- .../workloads/mnist/workload.py | 8 +++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index 4a41c2c5b..29d2b64ab 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -1,11 +1,13 @@ """MNIST workload implemented in PyTorch.""" from collections import OrderedDict import contextlib +import os from typing import Any, Dict, Tuple import torch from torch import nn import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP import torch.utils.data as pytorch_data_utils from torchvision import transforms from torchvision.datasets import MNIST @@ -15,7 +17,10 @@ from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload -DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +PYTORCH_DDP = 'LOCAL_RANK' in os.environ +RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0 +DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu') +N_GPUS = torch.cuda.device_count() class _Model(nn.Module): @@ -71,14 +76,26 @@ def _build_dataset(self, generator=torch.Generator().manual_seed(int(data_rng[0]))) # TODO: set seeds properly is_train = split == 'train' + + sampler = None 
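+    # Under DDP each process loads a disjoint shard of the data:
+    # DistributedSampler for training, DistributedEvalSampler for evaluation,
+    # so that no example is counted twice.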
+ if PYTORCH_DDP: + if is_train: + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) + else: + sampler = data_utils.DistributedEvalSampler( + dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False) + batch_size //= N_GPUS dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, - shuffle=is_train, + shuffle=not PYTORCH_DDP and is_train, + sampler=sampler, + num_workers=0, pin_memory=True, drop_last=is_train) if is_train: - dataloader = data_utils.cycle(dataloader) + dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP) return dataloader @@ -117,9 +134,12 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._param_shapes = { k: spec.ShapeTuple(v.shape) for k, v in model.named_parameters() } - if torch.cuda.device_count() > 1: - model = torch.nn.DataParallel(model) model.to(DEVICE) + if N_GPUS > 1: + if PYTORCH_DDP: + model = DDP(model, device_ids=[RANK], output_device=RANK) + else: + model = torch.nn.DataParallel(model) return model, None def model_fn( @@ -180,6 +200,6 @@ def _eval_model( update_batch_norm=False) _, predicted = torch.max(logits.data, 1) # Number of correct predictions. - accuracy = (predicted == batch['targets']).sum().item() - loss = self.loss_fn(batch['targets'], logits).sum().item() + accuracy = (predicted == batch['targets']).sum() + loss = self.loss_fn(batch['targets'], logits).sum() return {'accuracy': accuracy, 'loss': loss} diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index 015d44be7..3a0db1d97 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -1,16 +1,19 @@ """MNIST workload parent class.""" import itertools import math +import os from typing import Dict, Tuple from absl import flags from flax import jax_utils import jax +import torch.distributed as dist from algorithmic_efficiency import spec import algorithmic_efficiency.random_utils as prng FLAGS = flags.FLAGS +PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseMnistWorkload(spec.Workload): @@ -118,4 +121,9 @@ def _eval_model_on_split(self, } if FLAGS.framework == 'jax': total_metrics = jax_utils.unreplicate(total_metrics) + elif PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + if FLAGS.framework == 'pytorch': + total_metrics = {k: v.item() for k, v in total_metrics.items()} return {k: float(v / num_examples) for k, v in total_metrics.items()}
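
For reference, the following is a minimal sketch, not taken from the patches above, of how the helpers introduced in this series (data_utils.cycle and data_utils.DistributedEvalSampler) are meant to be combined in a PyTorch workload. It assumes the script is launched with torchrun so that LOCAL_RANK is set; build_loaders and the toy TensorDataset are illustrative placeholders, not part of the repository.

    import os

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    from algorithmic_efficiency import data_utils


    def build_loaders(batch_size=64):
      # Toy stand-in for the MNIST/ImageNet datasets used by the workloads.
      dataset = TensorDataset(torch.randn(512, 1, 28, 28),
                              torch.randint(0, 10, (512,)))
      ddp = 'LOCAL_RANK' in os.environ  # set by torchrun
      train_sampler = eval_sampler = None
      if ddp:
        rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(rank)
        dist.init_process_group('nccl')
        n_gpus = torch.cuda.device_count()
        # Training uses the standard DistributedSampler, which pads the dataset
        # to a multiple of n_gpus; evaluation uses DistributedEvalSampler,
        # which does not add extra samples.
        train_sampler = DistributedSampler(
            dataset, num_replicas=n_gpus, rank=rank, shuffle=True)
        eval_sampler = data_utils.DistributedEvalSampler(
            dataset, num_replicas=n_gpus, rank=rank, shuffle=False)
        batch_size //= n_gpus
      train_loader = DataLoader(
          dataset,
          batch_size=batch_size,
          shuffle=train_sampler is None,
          sampler=train_sampler,
          drop_last=True)
      eval_loader = DataLoader(
          dataset, batch_size=batch_size, sampler=eval_sampler)
      # cycle() converts each (inputs, targets) tuple into the
      # {'inputs', 'targets'} dict the workloads expect and, with
      # custom_sampler=True, calls sampler.set_epoch() whenever the underlying
      # loader is exhausted.
      train_iter = data_utils.cycle(train_loader, custom_sampler=ddp)
      return train_iter, eval_loader

Using DistributedEvalSampler for evaluation matters because it does not pad the shards: per-process sums of accuracy and loss can be all_reduce'd and divided by num_examples, as done in the workloads above, without double-counting any example.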