From dd5dcc9a01845e0b0bc660ca9fc4e429ee4afa3d Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 13 Jun 2022 22:51:07 +0200 Subject: [PATCH 01/13] Address Zack's comments --- .../imagenet/imagenet_jax/input_pipeline.py | 2 -- .../workloads/imagenet/imagenet_jax/workload.py | 2 ++ .../workloads/mnist/mnist_jax/workload.py | 7 ++++++- .../workloads/mnist/mnist_pytorch/workload.py | 12 +++--------- algorithmic_efficiency/workloads/mnist/workload.py | 5 +---- 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py index eee992175..d02ce2099 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/input_pipeline.py @@ -214,8 +214,6 @@ def create_split(split, area_range=(0.08, 1.0)): """Creates a split from the ImageNet dataset using TensorFlow Datasets.""" del num_batches - if split == 'eval_train': - split = 'train[:50000]' shuffle_rng, preprocess_rng = jax.random.split(rng, 2) diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py index d3b4c66da..4c4252154 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_jax/workload.py @@ -57,6 +57,8 @@ def _build_dataset(self, ds_builder = tfds.builder('imagenet2012:5.*.*', data_dir=data_dir) ds_builder.download_and_prepare() train = split == 'train' + if split == 'eval_train': + split = f'train[:{self.num_eval_train_examples}]' ds = input_pipeline.create_input_iter( split, ds_builder, diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index a3c433e71..8fe748515 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -1,5 +1,6 @@ """MNIST workload implemented in Jax.""" import functools +import itertools from typing import Any, Dict, Tuple from flax import jax_utils @@ -102,7 +103,11 @@ def build_input_queue(self, split: str, data_dir: str, global_batch_size: int) -> Dict[str, Any]: - return self._build_dataset(data_rng, split, data_dir, global_batch_size) + ds = self._build_dataset(data_rng, split, data_dir, global_batch_size) + if split != 'train': + # Note that this stores the entire eval dataset in memory. 
+ ds = itertools.cycle(ds) + return ds def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: init_val = jnp.ones((1, 28, 28, 1), jnp.float32) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index 29d2b64ab..720b1475e 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -94,8 +94,7 @@ def _build_dataset(self, num_workers=0, pin_memory=True, drop_last=is_train) - if is_train: - dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP) + dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP) return dataloader @@ -118,14 +117,9 @@ def build_input_queue(self, global_batch_size: int) -> Dict[str, Any]: it = self._build_dataset(data_rng, split, data_dir, global_batch_size) for batch in it: - if isinstance(batch, dict): - inputs = batch['inputs'] - targets = batch['targets'] - else: - inputs, targets = batch yield { - 'inputs': inputs.to(DEVICE, non_blocking=True), - 'targets': targets.to(DEVICE, non_blocking=True), + 'inputs': batch['inputs'].to(DEVICE, non_blocking=True), + 'targets': batch['targets'].to(DEVICE, non_blocking=True), } def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index 3a0db1d97..404f70861 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -1,5 +1,4 @@ """MNIST workload parent class.""" -import itertools import math import os from typing import Dict, Tuple @@ -99,10 +98,8 @@ def _eval_model_on_split(self, """Run a full evaluation of the model.""" data_rng, model_rng = prng.split(rng, 2) if split not in self._eval_iters: - eval_iter = self.build_input_queue( + self._eval_iters[split] = self.build_input_queue( data_rng, split, data_dir, global_batch_size=global_batch_size) - # Note that this stores the entire eval dataset in memory. - self._eval_iters[split] = itertools.cycle(eval_iter) total_metrics = { 'accuracy': 0., From 0452f1f407f7961ee3ad6426e42ef64ddb890846 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 Jun 2022 11:43:59 +0200 Subject: [PATCH 02/13] Change ImageNet batch size + fix minor diffs --- .../imagenet/imagenet_jax/submission.py | 2 +- .../imagenet/imagenet_pytorch/submission.py | 25 +++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/reference_submissions/imagenet/imagenet_jax/submission.py b/reference_submissions/imagenet/imagenet_jax/submission.py index 39634d1bf..b88659b71 100644 --- a/reference_submissions/imagenet/imagenet_jax/submission.py +++ b/reference_submissions/imagenet/imagenet_jax/submission.py @@ -15,7 +15,7 @@ def get_batch_size(workload_name): # Return the global batch size. del workload_name - return 128 + return 512 def create_learning_rate_fn(hparams: spec.Hyperparameters, diff --git a/reference_submissions/imagenet/imagenet_pytorch/submission.py b/reference_submissions/imagenet/imagenet_pytorch/submission.py index 85e523027..77b68118a 100644 --- a/reference_submissions/imagenet/imagenet_pytorch/submission.py +++ b/reference_submissions/imagenet/imagenet_pytorch/submission.py @@ -11,8 +11,8 @@ def get_batch_size(workload_name): # Return the global batch size. 
- batch_sizes = {'imagenet': 128} - return batch_sizes[workload_name] + del workload_name + return 512 def init_optimizer_state(workload: spec.Workload, @@ -20,34 +20,36 @@ def init_optimizer_state(workload: spec.Workload, model_state: spec.ModelAuxiliaryState, hyperparameters: spec.Hyperparameters, rng: spec.RandomState) -> spec.OptimizerState: - del workload del model_state del rng - base_lr = hyperparameters.learning_rate * get_batch_size('imagenet') / 256. + batch_size = get_batch_size('imagenet') + base_lr = hyperparameters.learning_rate * batch_size / 256. optimizer_state = { 'optimizer': torch.optim.SGD( model_params.parameters(), lr=base_lr, momentum=hyperparameters.momentum, - weight_decay=hyperparameters.l2) + weight_decay=hyperparameters.l2, + nesterov=True) } + steps_per_epoch = workload.num_train_examples // batch_size scheduler1 = LinearLR( optimizer_state['optimizer'], - start_factor=1e-5, + start_factor=1e-10, end_factor=1., - total_iters=hyperparameters.warmup_epochs) + total_iters=hyperparameters.warmup_epochs * steps_per_epoch) cosine_epochs = max( hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs) + optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) optimizer_state['scheduler'] = SequentialLR( optimizer_state['optimizer'], schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs]) + milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) return optimizer_state @@ -88,10 +90,7 @@ def update_params( loss.backward() optimizer_state['optimizer'].step() - - steps_per_epoch = workload.num_train_examples // get_batch_size('imagenet') - if (global_step + 1) % steps_per_epoch == 0: - optimizer_state['scheduler'].step() + optimizer_state['scheduler'].step() return (optimizer_state, current_param_container, new_model_state) From 955daf2b7f38697408dbfbbeed476359e98c3c2d Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 Jun 2022 11:58:35 +0200 Subject: [PATCH 03/13] Add fast_collate fn and PrefetchedWrapper for ImageNet PyTorch workload --- algorithmic_efficiency/data_utils.py | 71 ++++++++++++++++++- .../imagenet/imagenet_pytorch/workload.py | 23 ++---- 2 files changed, 77 insertions(+), 17 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 56038d96d..1146b77e1 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -1,6 +1,9 @@ import jax +import numpy as np import torch import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data import DistributedSampler from torch.utils.data import Sampler @@ -33,7 +36,7 @@ def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False): assert len(keys) == len(batch) yield dict(zip(keys, batch)) except StopIteration: - if custom_sampler: + if custom_sampler and isinstance(iterable, DataLoader): epoch += 1 iterable.sampler.set_epoch(epoch) iterator = iter(iterable) @@ -144,3 +147,69 @@ def set_epoch(self, epoch): epoch (int): _epoch number. 
""" self.epoch = epoch + + +# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ +# ConvNets/image_classification/dataloaders.py +def fast_collate(batch, memory_format=torch.contiguous_format): + imgs = [img[0] for img in batch] + targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) + w = imgs[0].size[0] + h = imgs[0].size[1] + tensor = torch.zeros( + (len(imgs), 3, h, w), + dtype=torch.uint8).contiguous(memory_format=memory_format) + for i, img in enumerate(imgs): + nump_array = np.asarray(img, dtype=np.uint8) + if nump_array.ndim < 3: + nump_array = np.expand_dims(nump_array, axis=-1) + nump_array = np.rollaxis(nump_array, 2) + tensor[i] += torch.from_numpy(nump_array.copy()) + return tensor, targets + + +# Inspired by +# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ +# ConvNets/image_classification/dataloaders.py +class PrefetchedWrapper: + + def __init__(self, dataloader, device, mean, std, start_epoch=0): + self.dataloader = dataloader + self.epoch = start_epoch + self.device = device + self.data_mean = torch.tensor([i / 255 for i in mean], + device=device).view(1, 3, 1, 1) + self.data_std = torch.tensor([i / 255 for i in std], + device=device).view(1, 3, 1, 1) + + def __len__(self): + return len(self.dataloader) + + def __iter__(self): + if isinstance(self.dataloader.sampler, + (DistributedSampler, DistributedEvalSampler)): + self.dataloader.sampler.set_epoch(self.epoch) + self.epoch += 1 + return self.prefetched_loader() + + def prefetched_loader(self): + stream = torch.cuda.Stream() + first = True + + for next_inputs, next_targets in self.dataloader: + with torch.cuda.stream(stream): + next_inputs = next_inputs.to( + self.device, dtype=torch.float, + non_blocking=True).sub(self.data_mean).div(self.data_std) + next_targets = next_targets.to(self.device, non_blocking=True) + + if not first: + yield inputs, targets + else: + first = False + + torch.cuda.current_stream().wait_stream(stream) + inputs = next_inputs + targets = next_targets + + yield inputs, targets diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py index 05c0f7d18..809f697d0 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py @@ -54,12 +54,7 @@ def build_input_queue(self, split: str, data_dir: str, global_batch_size: int): - it = self._build_dataset(data_rng, split, data_dir, global_batch_size) - for batch in it: - yield { - 'inputs': batch['inputs'].float().to(DEVICE, non_blocking=True), - 'targets': batch['targets'].to(DEVICE, non_blocking=True), - } + return self._build_dataset(data_rng, split, data_dir, global_batch_size) def _build_dataset(self, data_rng: spec.RandomState, @@ -69,16 +64,9 @@ def _build_dataset(self, del data_rng is_train = split == 'train' - normalize = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize( - mean=[i / 255 for i in self.train_mean], - std=[i / 255 for i in self.train_stddev]) - ]) eval_transform_config = transforms.Compose([ transforms.Resize(self.resize_size), transforms.CenterCrop(self.center_crop_size), - normalize ]) transform_config = { 'train': @@ -88,7 +76,6 @@ def _build_dataset(self, scale=self.scale_ratio_range, ratio=self.aspect_ratio_range), transforms.RandomHorizontalFlip(), - normalize ]), 'eval_train': eval_transform_config, @@ -121,10 +108,14 @@ def 
_build_dataset(self, batch_size=batch_size, shuffle=not PYTORCH_DDP and is_train, sampler=sampler, - num_workers=0, + num_workers=4, pin_memory=True, + collate_fn=data_utils.fast_collate, drop_last=is_train) - + dataloader = data_utils.PrefetchedWrapper(dataloader, + DEVICE, + self.train_mean, + self.train_stddev) dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP) return dataloader From d39aab844ae84fbe826a378cf1fee9934b96a255 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 Jun 2022 12:02:24 +0200 Subject: [PATCH 04/13] Fix README typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3f72a6d52..d24701850 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ Docker is the easiest way to enable PyTorch/JAX GPU support on Linux since only ```bash python3 submission_runner.py \ --framework=jax \ - --workload=mnist |\ + --workload=mnist \ --submission_path=baselines/mnist/mnist_jax/submission.py \ --tuning_search_space=baselines/mnist/tuning_search_space.json ``` From c2a1b6c83f66b256d1a9db768b90169f44ff7485 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 Jun 2022 12:04:49 +0200 Subject: [PATCH 05/13] Ignore bash scripts and output files --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index d53d91445..3fd247b7e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ env/ venv/ workdir/ makefile +*.out +*.sh From c89e1a41906659d1ce44e7e8cd403b422ca4f091 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 15 Jun 2022 12:15:01 +0200 Subject: [PATCH 06/13] Rename PYTORCH_DDP to USE_PYTORCH_DDP --- .../imagenet/imagenet_pytorch/workload.py | 14 +++++++------- .../workloads/mnist/mnist_pytorch/workload.py | 12 ++++++------ algorithmic_efficiency/workloads/mnist/workload.py | 4 ++-- submission_runner.py | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py index 809f697d0..480896a2a 100644 --- a/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet/imagenet_pytorch/workload.py @@ -22,8 +22,8 @@ from algorithmic_efficiency.workloads.imagenet.workload import \ BaseImagenetWorkload -PYTORCH_DDP = 'LOCAL_RANK' in os.environ -RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0 +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0 DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu') N_GPUS = torch.cuda.device_count() @@ -95,7 +95,7 @@ def _build_dataset(self, range(self.num_eval_train_examples)) sampler = None - if PYTORCH_DDP: + if USE_PYTORCH_DDP: if is_train: sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) @@ -106,7 +106,7 @@ def _build_dataset(self, dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, - shuffle=not PYTORCH_DDP and is_train, + shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, num_workers=4, pin_memory=True, @@ -116,7 +116,7 @@ def _build_dataset(self, DEVICE, self.train_mean, self.train_stddev) - dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP) + dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) return dataloader @@ -128,7 +128,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: } 
model.to(DEVICE) if N_GPUS > 1: - if PYTORCH_DDP: + if USE_PYTORCH_DDP: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = DDP(model, device_ids=[RANK], output_device=RANK) else: @@ -236,7 +236,7 @@ def _eval_model_on_split(self, total_metrics = { k: v + batch_metrics[k] for k, v in total_metrics.items() } - if PYTORCH_DDP: + if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index 720b1475e..694facc33 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -17,8 +17,8 @@ from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload -PYTORCH_DDP = 'LOCAL_RANK' in os.environ -RANK = int(os.environ['LOCAL_RANK']) if PYTORCH_DDP else 0 +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0 DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu') N_GPUS = torch.cuda.device_count() @@ -78,7 +78,7 @@ def _build_dataset(self, is_train = split == 'train' sampler = None - if PYTORCH_DDP: + if USE_PYTORCH_DDP: if is_train: sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True) @@ -89,12 +89,12 @@ def _build_dataset(self, dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, - shuffle=not PYTORCH_DDP and is_train, + shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, num_workers=0, pin_memory=True, drop_last=is_train) - dataloader = data_utils.cycle(dataloader, custom_sampler=PYTORCH_DDP) + dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) return dataloader @@ -130,7 +130,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: } model.to(DEVICE) if N_GPUS > 1: - if PYTORCH_DDP: + if USE_PYTORCH_DDP: model = DDP(model, device_ids=[RANK], output_device=RANK) else: model = torch.nn.DataParallel(model) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index 404f70861..8014f09bb 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -12,7 +12,7 @@ import algorithmic_efficiency.random_utils as prng FLAGS = flags.FLAGS -PYTORCH_DDP = 'LOCAL_RANK' in os.environ +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ class BaseMnistWorkload(spec.Workload): @@ -118,7 +118,7 @@ def _eval_model_on_split(self, } if FLAGS.framework == 'jax': total_metrics = jax_utils.unreplicate(total_metrics) - elif PYTORCH_DDP: + elif USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) if FLAGS.framework == 'pytorch': diff --git a/submission_runner.py b/submission_runner.py index 097990caa..d6fed3e59 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -297,13 +297,13 @@ def score_submission_on_workload(workload: spec.Workload, def main(_): # Check if distributed data parallel is used. - pytorch_ddp = 'LOCAL_RANK' in os.environ + use_pytorch_ddp = 'LOCAL_RANK' in os.environ if FLAGS.framework == 'pytorch': # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." 
torch.backends.cudnn.benchmark = True - if pytorch_ddp: + if use_pytorch_ddp: rank = int(os.environ['LOCAL_RANK']) torch.cuda.set_device(rank) # only log once (for local rank == 0) @@ -335,7 +335,7 @@ def logging_pass(*args): FLAGS.num_tuning_trials) logging.info('Final %s score: %f', FLAGS.workload, score) - if pytorch_ddp: + if use_pytorch_ddp: # cleanup dist.destroy_process_group() From 676149e454e538a62a33e7d02ead60d169f715ea Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 21 Jun 2022 16:47:22 +0200 Subject: [PATCH 07/13] Add DDP instructions to README --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index d24701850..d09d90f38 100644 --- a/README.md +++ b/README.md @@ -151,6 +151,14 @@ python3 submission_runner.py \ --tuning_search_space=baselines/mnist/tuning_search_space.json ``` +When using multiple GPUs on a single node it is recommended to use PyTorch's +[distributed data parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). +To do so, simply replace `python3` by +```bash +torchrun --standalone --nnodes=1 --nproc_per_node=N_GPUS +``` +where `N_GPUS` is the number of available GPUs on the node. + ## Rules The rules for the MLCommons Algorithmic Efficency benchmark can be found in the seperate [rules document](RULES.md). Suggestions, clarifications and questions can be raised via pull requests. From ba0daf2475d7935c23dee7397c9ebe6f3069410e Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 21 Jun 2022 17:15:09 +0200 Subject: [PATCH 08/13] Add TFDistributedSampler --- algorithmic_efficiency/data_utils.py | 37 +++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 1146b77e1..137859510 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -8,7 +8,7 @@ def shard_numpy_ds(xs): - """Prepare tf data for JAX + """Prepare tf data for JAX or PyTorch DDP. Convert an input batch from tf Tensors to numpy arrays and reshape it to be sharded across devices. @@ -19,8 +19,9 @@ def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. x = x._numpy() # pylint: disable=protected-access - # reshape (host_batch_size, height, width, 3) to - # (local_devices, device_batch_size, height, width, 3) + # Reshape (global_batch_size, ...) to + # (local_device_count, per_device_batch_size, ...). + # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1) + x.shape[1:]) return jax.tree_map(_prepare, xs) @@ -57,7 +58,7 @@ class DistributedEvalSampler(Sampler): Sampler that restricts data loading to a subset of the dataset. It is especially useful in conjunction with :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each - process can pass a :class`~torch.utils.data.DistributedSampler` instance as + process can pass a :class`~DistributedEvalSampler` instance as a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the original dataset that is exclusive to it. .. 
note:: @@ -213,3 +214,31 @@ def prefetched_loader(self): targets = next_targets yield inputs, targets + + +# Inspired by github.com/PetrochukM/PyTorch-NLP/blob/master/torchnlp/samplers/ +# distributed_sampler.py +class TFDistributedSampler: + + def __init__(self, iterator, device='cuda:0', rank=None): + self.iterator = iterator + self.device = device + self.rank = rank + if rank is None: + if not torch.distributed.is_initialized(): + raise RuntimeError('Requires `torch.distributed` to be initialized.') + self.rank = torch.distributed.get_rank() + + def __iter__(self): + return self + + def __next__(self): + batch = next(self.iterator) + batch = { + # Assumes that len(value) > self.rank, i.e. there needs to be data for + # each rank/GPU. + key: torch.as_tensor( + value[self.rank], device=self.device, dtype=torch.int64) for key, + value in batch.items() + } + return batch From d1cac47e79f709154da37ee9b4bdbf63db8776f6 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 21 Jun 2022 17:19:49 +0200 Subject: [PATCH 09/13] Adjust WMT input pipeline for DDP --- .../workloads/wmt/input_pipeline.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/input_pipeline.py b/algorithmic_efficiency/workloads/wmt/input_pipeline.py index 6a3f857f5..a608dfa83 100644 --- a/algorithmic_efficiency/workloads/wmt/input_pipeline.py +++ b/algorithmic_efficiency/workloads/wmt/input_pipeline.py @@ -3,10 +3,10 @@ import os from typing import Dict, List, Optional, Union -import jax import tensorflow as tf import tensorflow_datasets as tfds +from algorithmic_efficiency import data_utils from algorithmic_efficiency.workloads.wmt import tokenizer AUTOTUNE = tf.data.AUTOTUNE @@ -211,7 +211,7 @@ def preprocess_wmt_data(dataset: tf.data.Dataset, train: bool, shuffle_buffer_size: int = 1024, max_length: int = 512, - per_device_batch_size: int = 256): + batch_size: int = 256): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): @@ -230,10 +230,10 @@ def filter_fn(x): dataset = dataset.shuffle(shuffle_buffer_size, seed=data_rng[0]) dataset = dataset.repeat() dataset = pack_dataset(dataset, max_length) - dataset = dataset.batch(per_device_batch_size, drop_remainder=train) + dataset = dataset.batch(batch_size, drop_remainder=train) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( - per_device_batch_size, + batch_size, padded_shapes={'inputs': max_length, 'targets': max_length}, padding_values={'inputs': 0, 'targets': 0}, drop_remainder=train) @@ -273,21 +273,19 @@ def get_wmt_dataset(data_rng, ds, vocab_path=vocab_path, vocab_size=vocab_size, max_corpus_chars=10**7) ds = ds.map(tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) - num_devices = jax.local_device_count() - per_device_batch_size = global_batch_size // num_devices ds = preprocess_wmt_data( ds, data_rng, train=is_training, - per_device_batch_size=per_device_batch_size, + batch_size=global_batch_size, max_length=256) if num_batches: ds = ds.take(num_batches) - ds = ds.batch(num_devices) - if repeat_final_dataset: ds = ds.repeat() + ds = map(data_utils.shard_numpy_ds, ds) + return ds, sp_tokenizer From 102aaaf5a04e4d78bb5455787fb7dca08f70a056 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 21 Jun 2022 17:20:51 +0200 Subject: [PATCH 10/13] Fix WMT Jax workload --- .../workloads/wmt/wmt_jax/workload.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py 
b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 9e56b31bf..94fae02e9 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -19,12 +19,6 @@ from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload -def _pad_examples(desired_batch_size, x): - """Expand batch to desired size by repeating last slice.""" - batch_pad = desired_batch_size - x.shape[0] - return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0) - - def _per_host_sum_pmap(in_tree): """Execute psum on in_tree's leaves over one device per host.""" host2devices = collections.defaultdict(list) @@ -192,18 +186,10 @@ def translate_and_calculate_bleu(self, num_batches, max_predict_length: int): """Translates the `predict_ds` and calculates the BLEU score.""" - n_devices = jax.local_device_count() logging.info('Translating evaluation dataset.') sources, references, predictions = [], [], [] for _ in range(num_batches): pred_batch = next(ds_iter) - # Handle final odd-sized batch by padding instead of dropping it. - cur_pred_batch_size = pred_batch['inputs'].shape[0] - if cur_pred_batch_size % n_devices: - padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices) - pred_batch = jax.tree_map( - functools.partial(_pad_examples, padded_size), # pylint: disable=cell-var-from-loop - pred_batch) cache = self.initialize_cache(pred_batch['inputs']) predicted = self.predict_step(pred_batch['inputs'], params, @@ -214,7 +200,7 @@ def translate_and_calculate_bleu(self, inputs = _to_host(pred_batch['inputs']) targets = _to_host(pred_batch['targets']) # Iterate through non-padding examples of batch. - for i, s in enumerate(predicted[:cur_pred_batch_size]): + for i, s in enumerate(predicted): sources.append(self._decode_tokens(inputs[i])) references.append(self._decode_tokens(targets[i])) predictions.append(self._decode_tokens(s)) From 4c57f0565ec8b0980be4a461350b56866ea497f3 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 21 Jun 2022 17:29:45 +0200 Subject: [PATCH 11/13] Add DDP to WMT --- algorithmic_efficiency/workloads/wmt/bleu.py | 10 +- .../workloads/wmt/wmt_pytorch/workload.py | 98 +++++++++---------- .../workloads/wmt/workload.py | 47 +++++---- 3 files changed, 85 insertions(+), 70 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py index e25e751da..9d50e12b6 100644 --- a/algorithmic_efficiency/workloads/wmt/bleu.py +++ b/algorithmic_efficiency/workloads/wmt/bleu.py @@ -138,10 +138,12 @@ def compute_bleu_matches(reference_corpus, translation_corpus, max_order=4): possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[ngram] - return (np.array(matches_by_order), - np.array(possible_matches_by_order), - np.array(reference_length), - np.array(translation_length)) + return [ + np.array(matches_by_order), + np.array(possible_matches_by_order), + np.array(reference_length), + np.array(translation_length) + ] def bleu_partial(ref_lines, hyp_lines, case_sensitive=False): diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index b28c6baa7..97fe24a84 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -1,5 +1,6 @@ """WMT workload implemented in PyTorch.""" import contextlib +import os from typing import Dict, Optional, Tuple from absl import logging @@ -7,8 +8,11 @@ 
import numpy as np import tensorflow as tf import torch +import torch.distributed as dist import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from algorithmic_efficiency import data_utils from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.wmt import bleu @@ -16,7 +20,10 @@ from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload -DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +RANK = int(os.environ['LOCAL_RANK']) if USE_PYTORCH_DDP else 0 +DEVICE = torch.device(f'cuda:{RANK}' if torch.cuda.is_available() else 'cpu') +N_GPUS = torch.cuda.device_count() def _jax_to_pytorch(x: spec.Tensor, take_ownership: bool = False): @@ -31,17 +38,17 @@ def _pytorch_to_jax(x: spec.Tensor): class CrossEntropyLoss(torch.nn.CrossEntropyLoss): - def forward(self, logits, targets, label_smoothing=0.1): - vocab_size = logits.shape[-1] + def forward(self, input, target, label_smoothing=0.1): + vocab_size = input.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * np.log(confidence) + (vocab_size - 1) * low_confidence * np.log(low_confidence + 1e-20)) - one_hot_targets = F.one_hot(targets, num_classes=vocab_size) + one_hot_targets = F.one_hot(target, num_classes=vocab_size) soft_targets = torch.where(one_hot_targets == 1, confidence, low_confidence) loss = super().forward( - input=logits.transpose(-2, -1), target=soft_targets.transpose(-2, -1)) + input=input.transpose(-2, -1), target=soft_targets.transpose(-2, -1)) return loss - normalizing_constant @@ -70,7 +77,7 @@ def compute_weighted_cross_entropy(self, (str(logits.shape), str(targets.shape))) loss_fn = CrossEntropyLoss(reduction='none') - if torch.cuda.device_count() > 1: + if N_GPUS > 1 and not USE_PYTORCH_DDP: loss_fn = torch.nn.DataParallel(loss_fn) loss = loss_fn(logits, targets, label_smoothing=label_smoothing) @@ -84,23 +91,23 @@ def compute_weighted_cross_entropy(self, def predict_step(self, inputs, params, eos_id, max_decode_len, beam_size=4): """Predict translation with fast decoding beam search on a batch.""" # This means that decoding will always happen on a single GPU! 
- params = params.module if isinstance(params, - torch.nn.DataParallel) else params + params = params.module if isinstance(params, (torch.nn.DataParallel, + DDP)) else params params.eval() encoder = params.encoder - if torch.cuda.device_count() > 1: + if N_GPUS > 1 and not USE_PYTORCH_DDP: encoder = torch.nn.DataParallel(encoder) encoded_inputs = torch.repeat_interleave( encoder(inputs), repeats=beam_size, dim=0) raw_inputs = torch.repeat_interleave(inputs, repeats=beam_size, dim=0) decoder = params.decoder - if torch.cuda.device_count() > 1: + if N_GPUS > 1 and not USE_PYTORCH_DDP: decoder = torch.nn.DataParallel(decoder) def tokens_ids_to_logits(flat_ids, flat_cache): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] - flat_ids = _jax_to_pytorch(flat_ids) + flat_ids = _jax_to_pytorch(flat_ids).to(DEVICE) flat_logits, new_flat_cache = decoder( flat_ids, encoded_inputs, @@ -144,19 +151,19 @@ def translate_and_calculate_bleu(self, num_batches: int, max_predict_length: int): """Translates the `ds_iter` and calculates the BLEU score.""" - n_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1 logging.info('Translating evaluation dataset.') sources, references, predictions = [], [], [] for _ in range(num_batches): pred_batch = next(ds_iter) inputs = pred_batch['inputs'] targets = pred_batch['targets'] - # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = inputs.shape[0] - if cur_pred_batch_size % n_devices: - padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices) - inputs = self.pad_examples(inputs, padded_size) # pylint: disable=cell-var-from-loop - targets = self.pad_examples(targets, padded_size) + if not USE_PYTORCH_DDP: + # Handle final odd-sized batch by padding instead of dropping it. + if cur_pred_batch_size % N_GPUS: + padded_size = int(np.ceil(cur_pred_batch_size / N_GPUS) * N_GPUS) + inputs = self.pad_examples(inputs, padded_size) # pylint: disable=cell-var-from-loop + targets = self.pad_examples(targets, padded_size) predicted = self.predict_step(inputs, params, decode.EOS_ID, @@ -167,13 +174,15 @@ def translate_and_calculate_bleu(self, sources.append(self._decode_tokens(inputs[i])) references.append(self._decode_tokens(targets[i])) predictions.append(self._decode_tokens(s)) - logging.info("Translation: %d predictions %d references %d sources.", - len(predictions), - len(references), - len(sources)) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) + if USE_PYTORCH_DDP: + # Sync matches across devices. 
+ for idx, array in enumerate(bleu_matches): + tensor = torch.as_tensor(array, device=DEVICE) + dist.all_reduce(tensor) + bleu_matches[idx] = tensor.cpu().numpy() bleu_score = bleu.complete_bleu(*bleu_matches) return bleu_score @@ -183,9 +192,12 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._param_shapes = { k: spec.ShapeTuple(v.shape) for k, v in model.named_parameters() } - if torch.cuda.device_count() > 1: - model = torch.nn.DataParallel(model) model.to(DEVICE) + if N_GPUS > 1: + if USE_PYTORCH_DDP: + model = DDP(model, device_ids=[RANK], output_device=RANK) + else: + model = torch.nn.DataParallel(model) return model, None def model_fn( @@ -238,14 +250,20 @@ def build_input_queue(self, global_batch_size, num_batches, repeat_final_dataset) - for batch in np_iter: - batch = { - key: torch.as_tensor(value, device=DEVICE, - dtype=torch.int64).view(-1, value.shape[-1]) - for key, - value in batch.items() - } - yield batch + if USE_PYTORCH_DDP: + return data_utils.TFDistributedSampler(np_iter, device=DEVICE, rank=RANK) + + def _input_queue_generator(): + for batch in np_iter: + batch = { + key: torch.as_tensor(value, device=DEVICE, + dtype=torch.int64).view(-1, value.shape[-1]) + for key, + value in batch.items() + } + yield batch + + return _input_queue_generator() def eval_step(self, params, batch): """Calculate evaluation metrics on a batch.""" @@ -260,24 +278,6 @@ def eval_step(self, params, batch): update_batch_norm=False) return self.compute_summed_metrics(logits, targets, weights) - def evaluate(self, - params: spec.ParameterContainer, - eval_ds: tf.data.Dataset, - num_eval_steps: int): - """Evaluate the model and return a dictionary with the metrics.""" - logging.info('Gathering evaluation metrics.') - eval_metrics = { - 'loss': 0., - 'accuracy': 0., - 'denominator': 0, - } - eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types - for _, eval_batch in zip(range(num_eval_steps), eval_iter): - metrics = self.eval_step(params, eval_batch) - eval_metrics = {k: v + metrics[k] for k, v in eval_metrics.items()} - denominator = eval_metrics.pop('denominator') - return {k: float(v / denominator) for k, v in eval_metrics.items()} - @property def model_params_types(self): if self._param_shapes is None: diff --git a/algorithmic_efficiency/workloads/wmt/workload.py b/algorithmic_efficiency/workloads/wmt/workload.py index 1ec80bac8..fb4871a15 100644 --- a/algorithmic_efficiency/workloads/wmt/workload.py +++ b/algorithmic_efficiency/workloads/wmt/workload.py @@ -1,9 +1,12 @@ import math +import os from typing import Dict, Optional +from absl import flags import jax import numpy as np import torch +import torch.distributed as dist from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.wmt import decode @@ -11,6 +14,8 @@ VOCAB_PATH = './wmt_256/sentencepiece_model' WORKDIR = './wmt_256' +USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ +FLAGS = flags.FLAGS class BaseWmtWorkload(spec.Workload): @@ -36,15 +41,18 @@ def loss_type(self): @property def num_train_examples(self): + # wmt17_translate/de-en 'train' split size return 5906184 @property def num_eval_train_examples(self): - return 3004 + # same as `num_validation_examples` + return 3000 @property def num_validation_examples(self): - return 3004 + # wmt14_translate/de-en 'validation' split size + return 3000 @property def num_test_examples(self): @@ -75,7 +83,9 @@ def build_input_queue(self, repeat_final_dataset: bool = False): is_training = split == 'train' if split == 'eval_train': - split 
= 'train' + # Without the '+1' only `num_eval_train_examples-1` examples are used + # since one example is filtered out in the input pipeline. + split = f'train[:{self.num_eval_train_examples+1}]' ds, self._tokenizer = input_pipeline.get_wmt_dataset( data_rng, split, @@ -87,7 +97,6 @@ def build_input_queue(self, reverse_translation=True, repeat_final_dataset=repeat_final_dataset) for batch in iter(ds): - batch = jax.tree_map(lambda x: x._numpy(), batch) # pylint: disable=protected-access yield batch def _eval_model_on_split(self, @@ -99,6 +108,7 @@ def _eval_model_on_split(self, rng: spec.RandomState, data_dir: str) -> Dict[str, float]: """Run a full evaluation of the model.""" + del model_state num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: # These iterators will repeat indefinitely. @@ -109,27 +119,30 @@ def _eval_model_on_split(self, global_batch_size, num_batches, repeat_final_dataset=True) - eval_metrics = [] + + eval_metrics = {} for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) metrics = self.eval_step(params, eval_batch) - eval_metrics.append(metrics) - eval_metrics_sums = {k: 0.0 for k in eval_metrics[0].keys()} - for m in eval_metrics: - for k, v in m.items(): - eval_metrics_sums[k] += v - eval_denominator = eval_metrics_sums.pop("denominator") - eval_results = jax.tree_map( - lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop - eval_metrics_sums) - - bleu_score = self.translate_and_calculate_bleu( + for metric_name, metric_value in metrics.items(): + if metric_name not in eval_metrics: + eval_metrics[metric_name] = 0.0 + eval_metrics[metric_name] += metric_value + if USE_PYTORCH_DDP: + for metric in eval_metrics.values(): + dist.all_reduce(metric) + if FLAGS.framework == 'pytorch': + eval_metrics = {k: v.item() for k, v in eval_metrics.items()} + eval_denominator = eval_metrics.pop('denominator') + eval_results = jax.tree_map(lambda x: float(x / eval_denominator), + eval_metrics) + + eval_results['bleu'] = self.translate_and_calculate_bleu( params=params, ds_iter=self._eval_iters[split], num_batches=num_batches, max_predict_length=256) - eval_results['bleu'] = bleu_score return eval_results def compute_summed_metrics(self, logits, labels, weights): From b6b6af803a6993160d95ab25fe6cc0d57e251c14 Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 21 Jun 2022 17:40:16 +0200 Subject: [PATCH 12/13] Ignore pylint error --- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 97fe24a84..f9d301302 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -38,7 +38,7 @@ def _pytorch_to_jax(x: spec.Tensor): class CrossEntropyLoss(torch.nn.CrossEntropyLoss): - def forward(self, input, target, label_smoothing=0.1): + def forward(self, input, target, label_smoothing=0.1): # pylint: disable=redefined-builtin vocab_size = input.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) From 5b9ef217d2f5bdc15302990f03c1094151cdd4ce Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 23 Jun 2022 14:20:04 +0200 Subject: [PATCH 13/13] Minor fixes --- .../workloads/wmt/input_pipeline.py | 10 +++++----- .../workloads/wmt/wmt_pytorch/workload.py | 17 +---------------- 2 files changed, 6 
insertions(+), 21 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/input_pipeline.py b/algorithmic_efficiency/workloads/wmt/input_pipeline.py index a608dfa83..e7540d1d7 100644 --- a/algorithmic_efficiency/workloads/wmt/input_pipeline.py +++ b/algorithmic_efficiency/workloads/wmt/input_pipeline.py @@ -210,8 +210,8 @@ def preprocess_wmt_data(dataset: tf.data.Dataset, data_rng, train: bool, shuffle_buffer_size: int = 1024, - max_length: int = 512, - batch_size: int = 256): + max_length: int = 256, + global_batch_size: int = 128): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): @@ -230,10 +230,10 @@ def filter_fn(x): dataset = dataset.shuffle(shuffle_buffer_size, seed=data_rng[0]) dataset = dataset.repeat() dataset = pack_dataset(dataset, max_length) - dataset = dataset.batch(batch_size, drop_remainder=train) + dataset = dataset.batch(global_batch_size, drop_remainder=train) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( - batch_size, + global_batch_size, padded_shapes={'inputs': max_length, 'targets': max_length}, padding_values={'inputs': 0, 'targets': 0}, drop_remainder=train) @@ -277,7 +277,7 @@ def get_wmt_dataset(data_rng, ds, data_rng, train=is_training, - batch_size=global_batch_size, + global_batch_size=global_batch_size, max_length=256) if num_batches: diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index f9d301302..f68693d5b 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -137,14 +137,6 @@ def tokens_ids_to_logits(flat_ids, flat_cache): # Return the highest scoring beam sequence, drop first dummy 0 token. return beam_seqs[:, -1, 1:] - # Utils for prediction and BLEU calculation - # ---------------------------------------------------------------------------- - - def pad_examples(self, x, desired_batch_size): - """Expand batch to desired size by repeating last slice.""" - batch_pad = desired_batch_size - x.shape[0] - return torch.cat([x, torch.tile(x[-1], (batch_pad, 1))], dim=0) - def translate_and_calculate_bleu(self, params: spec.ParameterContainer, ds_iter: tf.data.Dataset, @@ -157,20 +149,13 @@ def translate_and_calculate_bleu(self, pred_batch = next(ds_iter) inputs = pred_batch['inputs'] targets = pred_batch['targets'] - cur_pred_batch_size = inputs.shape[0] - if not USE_PYTORCH_DDP: - # Handle final odd-sized batch by padding instead of dropping it. - if cur_pred_batch_size % N_GPUS: - padded_size = int(np.ceil(cur_pred_batch_size / N_GPUS) * N_GPUS) - inputs = self.pad_examples(inputs, padded_size) # pylint: disable=cell-var-from-loop - targets = self.pad_examples(targets, padded_size) predicted = self.predict_step(inputs, params, decode.EOS_ID, max_predict_length) # Iterate through non-padding examples of batch. - for i, s in enumerate(predicted[:cur_pred_batch_size]): + for i, s in enumerate(predicted): sources.append(self._decode_tokens(inputs[i])) references.append(self._decode_tokens(targets[i])) predictions.append(self._decode_tokens(s))
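
For reference, the per-step warmup-plus-cosine learning-rate schedule that the ImageNet PyTorch submission constructs in patch 02 can be sketched in isolation roughly as follows. This is only an illustrative, self-contained sketch: the concrete hyperparameter values below are placeholders, not the submission's actual tuning values, and the linear model stands in for the real ResNet.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Illustrative placeholder values -- not the submission's tuned hyperparameters.
batch_size = 512
num_train_examples = 1_281_167  # ImageNet train split size.
warmup_epochs = 5
num_epochs = 100
base_lr = 0.1 * batch_size / 256.

model = torch.nn.Linear(10, 10)  # Stand-in for the actual ResNet model.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=base_lr,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True)

# Schedules are defined in optimizer steps rather than epochs.
steps_per_epoch = num_train_examples // batch_size
warmup = LinearLR(
    optimizer,
    start_factor=1e-10,
    end_factor=1.,
    total_iters=warmup_epochs * steps_per_epoch)
cosine_steps = max(num_epochs - warmup_epochs, 1) * steps_per_epoch
cosine = CosineAnnealingLR(optimizer, T_max=cosine_steps)
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_epochs * steps_per_epoch])

# The scheduler is stepped once per optimizer step, not once per epoch.
for _ in range(10):
    optimizer.step()
    scheduler.step()
```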