
Commit e42072c

awaelchli, carmocca, and tchaton committed

Enable distributed training with CombinedDataLoader and max_size_cycle (#10374)

Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Thomas Chaton <[email protected]>
1 parent 9903646 commit e42072c

File tree

4 files changed: +89 -7 lines

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
+


 ## [1.5.1] - 2021-11-09
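
To make the CHANGELOG entry concrete, here is a minimal sketch (not part of the commit) of the setup it describes: two dataloaders of unequal length combined with `mode="max_size_cycle"` and trained under DDP. Before this fix the inner loaders kept their plain samplers; after it, each receives a `DistributedSampler`. The model is assumed to be any `LightningModule` whose `training_step` accepts the dict batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import CombinedLoader

# Two loaders of different lengths; "max_size_cycle" cycles the shorter one
# until the longest loader is exhausted instead of stopping at the shortest.
train_loaders = CombinedLoader(
    {
        "a": DataLoader(TensorDataset(torch.randn(6, 32)), batch_size=1),
        "b": DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=1),
    },
    mode="max_size_cycle",
)

trainer = pl.Trainer(strategy="ddp", devices=2, max_epochs=1)
# trainer.fit(model, train_dataloaders=train_loaders)  # `model` is a placeholder LightningModule
```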

pytorch_lightning/trainer/data_loading.py

Lines changed: 15 additions & 3 deletions
@@ -28,7 +28,7 @@
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.trainer.supporters import CombinedLoader
+from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.auto_restart import (
@@ -136,14 +136,22 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
         if isinstance(dataloader, CombinedLoader):
             # apply `prepare_dataloader` on all the collection of loaders
             dataloader.loaders = apply_to_collection(
-                dataloader.loaders, DataLoader, self.prepare_dataloader, shuffle, mode=mode
+                dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
             )
+            # the length need to recomputed across all dataloaders in case of special behavior.
+            dataloader._apply_cycle_iterator_length()
             return dataloader

         # don't do anything if it's not a dataloader
-        if not isinstance(dataloader, DataLoader):
+        if not isinstance(dataloader, (DataLoader, CycleIterator)):
             return dataloader

+        cycle_iterator: Optional[CycleIterator] = None
+
+        if isinstance(dataloader, CycleIterator):
+            cycle_iterator = dataloader
+            dataloader = dataloader.loader
+
         if (
             _fault_tolerant_training()  # injects components to track the state
             or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler

@@ -153,6 +161,10 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
             sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
             dataloader = self._update_dataloader(dataloader, sampler, mode=mode)

+            if cycle_iterator is not None:
+                cycle_iterator.loader = dataloader
+                return cycle_iterator
+
         return dataloader

     def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
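
As a standalone illustration of why the hunks above unwrap the `CycleIterator`, patch the inner `DataLoader`, and then re-wrap and recompute lengths, the following sketch (plain PyTorch, not from the commit) shows how injecting a `DistributedSampler` shrinks the per-process length of a loader; `num_replicas` and `rank` are hard-coded here instead of coming from the process group.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())

# Without a distributed sampler: every process would iterate all 8 samples.
plain = DataLoader(dataset, batch_size=1)

# With a distributed sampler: each of the 2 replicas only sees its own shard,
# so any CycleIterator length cached before the sampler swap would be stale.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
sharded = DataLoader(dataset, batch_size=1, sampler=sampler)

print(len(plain))    # 8 batches
print(len(sharded))  # 4 batches per process
```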

pytorch_lightning/trainer/supporters.py

Lines changed: 17 additions & 3 deletions
@@ -457,6 +457,19 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
         )
         state.reset()

+    def _apply_cycle_iterator_length(self) -> None:
+        """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
+        all dataloaders."""
+        if self.mode != "max_size_cycle":
+            return
+
+        def set_len(cycle_iterator: CycleIterator, length: int) -> None:
+            cycle_iterator.length = length
+
+        all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader))
+        max_length = _nested_calc_num_data(all_lengths, max)
+        apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length)
+
     def __iter__(self) -> Any:
         """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader."""

@@ -473,11 +486,12 @@ def __getstate__patch__(*_):
         return iterator

     @staticmethod
-    def _calc_num_batches(loaders: Any) -> Union[int, float]:
+    def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
         """Compute the length (aka the number of batches) of `CombinedLoader`.

         Args:
             loaders: a collections of loaders.
+            mode: Mode used by the CombinedDataloader

         Returns:
             length: the minimum length of loaders

@@ -486,10 +500,10 @@ def _calc_num_batches(loaders: Any) -> Union[int, float]:

         if isinstance(all_lengths, (int, float)):
             return all_lengths
-        return _nested_calc_num_data(all_lengths, min)
+        return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)

     def __len__(self) -> int:
-        return self._calc_num_batches(self.loaders)
+        return self._calc_num_batches(self.loaders, mode=self.mode)

     @staticmethod
     def _shutdown_workers_and_reset_iterator(dataloader) -> None:
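
The change to `_calc_num_batches` boils down to choosing the reduction function by mode. A plain-Python sketch of that arithmetic over a nested collection of loader lengths (an illustration under assumptions, not library code):

```python
from typing import Any, Union

def calc_num_batches(lengths: Any, mode: str = "min_size") -> Union[int, float]:
    """Reduce a nested dict/list of loader lengths to a single number of batches."""
    if isinstance(lengths, (int, float)):
        return lengths
    values = lengths.values() if isinstance(lengths, dict) else lengths
    reducer = max if mode == "max_size_cycle" else min
    return reducer(calc_num_batches(value, mode) for value in values)

lengths = {"a": 6, "b": 8}
assert calc_num_batches(lengths) == 6                    # "min_size": stop at the shortest loader
assert calc_num_batches(lengths, "max_size_cycle") == 8  # "max_size_cycle": run to the longest
```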

tests/trainer/test_supporters.py

Lines changed: 55 additions & 0 deletions
@@ -33,8 +33,10 @@
 )
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
+from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
+from tests.helpers.boring_model import RandomDataset


 def test_tensor_running_accum_reset():
@@ -379,3 +381,56 @@ def _assert_dataset(loader):
         assert isinstance(d, CustomDataset)

     apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)
+
+
+@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
+def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, tmpdir):
+    """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
+    with ddp and `max_size_cycle` mode."""
+    trainer = Trainer(strategy="ddp", accelerator="auto", devices=2, replace_sampler_ddp=replace_sampler_ddp)
+
+    dataloader = CombinedLoader(
+        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
+    )
+    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
+    assert len(dataloader) == 4 if replace_sampler_ddp else 8
+
+    for a_length in [6, 8, 10]:
+        dataloader = CombinedLoader(
+            {
+                "a": DataLoader(range(a_length), batch_size=1),
+                "b": DataLoader(range(8), batch_size=1),
+            },
+            mode="max_size_cycle",
+        )
+
+        length = max(a_length, 8)
+        assert len(dataloader) == length
+        dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
+        assert len(dataloader) == length // 2 if replace_sampler_ddp else length
+        if replace_sampler_ddp:
+            last_batch = list(dataloader)[-1]
+            if a_length == 6:
+                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
+            elif a_length == 8:
+                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
+            elif a_length == 10:
+                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}
+
+    class InfiniteDataset(IterableDataset):
+        def __iter__(self):
+            while True:
+                yield 1
+
+    dataloader = CombinedLoader(
+        {
+            "a": DataLoader(InfiniteDataset(), batch_size=1),
+            "b": DataLoader(range(8), batch_size=1),
+        },
+        mode="max_size_cycle",
+    )
+    assert get_len(dataloader) == float("inf")
+    assert len(dataloader.loaders["b"].loader) == 8
+    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
+    assert len(dataloader.loaders["b"].loader) == 4 if replace_sampler_ddp else 8
+    assert get_len(dataloader) == float("inf")
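
The lengths asserted in the test above follow from the sampler arithmetic: with `replace_sampler_ddp=True` and 2 processes, each inner loader is cut to `ceil(n / 2)` batches, and the combined `max_size_cycle` length is the largest of those shards. A quick sanity check of the numbers used in the test (a sketch, with the world size hard-coded):

```python
import math

world_size = 2  # the test runs with devices=2
for a_length in (6, 8, 10):
    combined = max(a_length, 8)                      # max_size_cycle length before sharding
    per_process = math.ceil(combined / world_size)   # DistributedSampler pads, then shards evenly
    assert per_process == combined // 2              # all values here are even, matching `length // 2` in the test
    print(a_length, combined, per_process)
```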
