Skip to content

Fix bugs and implement DDP #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 13, 2022
146 changes: 146 additions & 0 deletions algorithmic_efficiency/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import jax
import torch
import torch.distributed as dist
from torch.utils.data import Sampler


def shard_numpy_ds(xs):
"""Prepare tf data for JAX

Convert an input batch from tf Tensors to numpy arrays and reshape it to be
sharded across devices.
"""
local_device_count = jax.local_device_count()

def _prepare(x):
# Use _numpy() for zero-copy conversion between TF and NumPy.
x = x._numpy() # pylint: disable=protected-access

# reshape (host_batch_size, height, width, 3) to
# (local_devices, device_batch_size, height, width, 3)
return x.reshape((local_device_count, -1) + x.shape[1:])

return jax.tree_map(_prepare, xs)


# github.com/pytorch/pytorch/issues/23900#issuecomment-518858050
def cycle(iterable, keys=('inputs', 'targets'), custom_sampler=False):
iterator = iter(iterable)
epoch = 0
while True:
try:
batch = next(iterator)
assert len(keys) == len(batch)
yield dict(zip(keys, batch))
except StopIteration:
if custom_sampler:
epoch += 1
iterable.sampler.set_epoch(epoch)
iterator = iter(iterable)


# github.com/SeungjunNah/DeepDeblur-PyTorch/blob/master/src/data/sampler.py
class DistributedEvalSampler(Sampler):
r"""
DistributedEvalSampler is different from DistributedSampler.
It does NOT add extra samples to make it evenly divisible.
DistributedEvalSampler should NOT be used for training. The distributed
processes could hang forever.
See this issue for details: https://github.com/pytorch/pytorch/issues/22584
shuffle is disabled by default
DistributedEvalSampler is for evaluation purpose where synchronization does
not happen every epoch.
Synchronization should be done outside the dataloader loop.
Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
process can pass a :class`~torch.utils.data.DistributedSampler` instance as
a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
original dataset that is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Arguments:
dataset: Dataset used for sampling.
num_replicas (int, optional): Number of processes participating in
distributed training. By default, :attr:`rank` is retrieved from the
current distributed group.
rank (int, optional): Rank of the current process within
:attr:`num_replicas`. By default, :attr:`rank` is retrieved from the
current distributed group.
shuffle (bool, optional): If ``True``, sampler will shuffle the
indices. Default: ``False``
seed (int, optional): random seed used to shuffle the sampler if
:attr:`shuffle=True`. This number should be identical across all
processes in the distributed group. Default: ``0``.
.. warning::
In distributed mode, calling the :meth`set_epoch(epoch) <set_epoch>`
method at the beginning of each epoch **before** creating the
:class:`DataLoader` iterator is necessary to make shuffling work
properly across multiple epochs. Otherwise, the same ordering will be
always used.
Example::
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
... sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
... if is_distributed:
... sampler.set_epoch(epoch)
... train(loader)
"""

def __init__(self,
dataset,
num_replicas=None,
rank=None,
shuffle=False,
seed=0):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError('Requires distributed package to be available.')
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError('Requires distributed package to be available.')
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
# true value without extra samples
self.total_size = len(self.dataset)
indices = list(range(self.total_size))
indices = indices[self.rank:self.total_size:self.num_replicas]
# true value without extra samples
self.num_samples = len(indices)

self.shuffle = shuffle
self.seed = seed

def __iter__(self):
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples

return iter(indices)

def __len__(self):
return self.num_samples

def set_epoch(self, epoch):
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this
ensures all replicas use a different random ordering for each epoch.
Otherwise, the next iteration of this sampler will yield the same
ordering.
Arguments:
epoch (int): _epoch number.
"""
self.epoch = epoch
5 changes: 4 additions & 1 deletion algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def build_input_queue(self,
data_rng: RandomState,
split: str,
data_dir: str,
global_batch_size: int) -> Dict[str, Any]:
global_batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None) -> Dict[str, Any]:
"""Build the input queue for the workload data.

This is the only function that is NOT allowed to be called by submitters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import tensorflow as tf
import tensorflow_datasets as tfds

from algorithmic_efficiency import data_utils

IMAGE_SIZE = 224
RESIZE_SIZE = 256
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
Expand Down Expand Up @@ -156,16 +158,18 @@ def preprocess_for_train(image_bytes,
Returns:
A preprocessed image `Tensor`.
"""
crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2)
# Note (runame): Cannot be done in graph mode, i.e. during ds.map().
# Alternative?
# crop_rng, flip_rng = tf.random.experimental.stateless_split(rng, 2)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@znado This doesn't work in graph mode (see my comment), what do you think is the best alternative?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

your current change is fine for now, it's not good to reuse seeds but lets just add it to the list of issues. what was the error you got? it's really weird that this couldn't be run in graph mode (but I believe it lol), all this fn does is call random_uniform

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like "stateless_split iterating over tf.Tensor is not allowed in Graph execution" (according to my google search history lol). It's especially weird because here, stateless_fold_in() works just fine and is also called inside of the ds.map() call, i.e. is also executed in graph mode. And it is calling the same underlying function. So the issue might be related to the shape argument in stateless_random_uniform(), but I haven't looked into it further.


image = _decode_and_random_crop(image_bytes,
crop_rng,
rng,
image_size,
aspect_ratio_range,
area_range,
resize_size)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng)
image = tf.image.stateless_random_flip_left_right(image, seed=rng)
image = normalize_image(image, mean_rgb, stddev_rgb)
image = tf.image.convert_image_dtype(image, dtype=dtype)
return image
Expand Down Expand Up @@ -209,22 +213,19 @@ def create_split(split,
aspect_ratio_range=(0.75, 4.0 / 3.0),
area_range=(0.08, 1.0)):
"""Creates a split from the ImageNet dataset using TensorFlow Datasets."""
del num_batches
if split == 'eval_train':
split = 'train'
split = 'train[:50000]'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch. should we instead in the caller function, pass in split='train[:{num_eval_train_examples}]'?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes a lot of sense, already changed it.


shuffle_rng, preprocess_rng = jax.random.split(rng, 2)

def decode_example(example):
def decode_example(example_index, example):
dtype = tf.float32
if train:
# We call ds.enumerate() to get a globally unique per-example, per-step
# index that we can fold into the RNG seed.
(example_index, example) = example
per_step_preprocess_rng = tf.random.experimental.stateless_fold_in(
tf.cast(preprocess_rng, tf.int64), example_index)
image = preprocess_for_train(example['image'],
per_step_preprocess_rng,
example_index,
mean_rgb,
stddev_rgb,
aspect_ratio_range,
Expand All @@ -246,7 +247,7 @@ def decode_example(example):
'image': tfds.decode.SkipDecoding(),
})
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = 48
options.threading.private_threadpool_size = 48
ds = ds.with_options(options)

if cache:
Expand All @@ -256,11 +257,11 @@ def decode_example(example):
ds = ds.repeat()
ds = ds.shuffle(16 * global_batch_size, seed=shuffle_rng[0])

# We call ds.enumerate() to get a globally unique per-example, per-step
# index that we can fold into the RNG seed.
ds = ds.enumerate()
ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.batch(global_batch_size, drop_remainder=True)

if num_batches is not None:
ds = ds.take(num_batches)
ds = ds.batch(global_batch_size, drop_remainder=train)

if repeat_final_dataset:
ds = ds.repeat()
Expand All @@ -270,25 +271,6 @@ def decode_example(example):
return ds


def shard_numpy_ds(xs):
"""Prepare tf data for JAX

Convert an input batch from tf Tensors to numpy arrays and reshape it to be
sharded across devices.
"""
local_device_count = jax.local_device_count()

def _prepare(x):
# Use _numpy() for zero-copy conversion between TF and NumPy.
x = x._numpy() # pylint: disable=protected-access

# reshape (host_batch_size, height, width, 3) to
# (local_devices, device_batch_size, height, width, 3)
return x.reshape((local_device_count, -1) + x.shape[1:])

return jax.tree_map(_prepare, xs)


def create_input_iter(split,
dataset_builder,
rng,
Expand All @@ -309,7 +291,6 @@ def create_input_iter(split,
rng,
global_batch_size,
train=train,
dtype=tf.float32,
image_size=image_size,
resize_size=resize_size,
mean_rgb=mean_rgb,
Expand All @@ -319,9 +300,9 @@ def create_input_iter(split,
num_batches=num_batches,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range)
it = map(shard_numpy_ds, ds)
it = map(data_utils.shard_numpy_ds, ds)

# Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%.
it = jax_utils.prefetch_to_device(it, 2)

return it
return iter(it)
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,17 @@ def build_input_queue(self,
data_rng: spec.RandomState,
split: str,
data_dir: str,
global_batch_size: int):
return iter(
self._build_dataset(data_rng, split, data_dir, global_batch_size))
global_batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None):
return self._build_dataset(data_rng,
split,
data_dir,
global_batch_size,
cache,
repeat_final_dataset,
num_batches)

def _build_dataset(self,
data_rng: spec.RandomState,
Expand Down Expand Up @@ -144,6 +152,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
Expand Down Expand Up @@ -171,13 +181,14 @@ def loss_fn(self, label_batch: spec.Tensor,
return xentropy

def _compute_metrics(self, logits, labels):
loss = jnp.mean(self.loss_fn(labels, logits))
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
loss = jnp.sum(self.loss_fn(labels, logits))
# not accuracy, but nr. of correct predictions
accuracy = jnp.sum(jnp.argmax(logits, -1) == labels)
metrics = {
'loss': loss,
'accuracy': accuracy,
}
metrics = lax.pmean(metrics, axis_name='batch')
metrics = lax.psum(metrics, axis_name='batch')
return metrics

def _eval_model_on_split(self,
Expand Down Expand Up @@ -213,5 +224,6 @@ def _eval_model_on_split(self,
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value

eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics)
eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples),
eval_metrics)
return eval_metrics
Loading