
Commit 235e692

Fabric: do set_epoch for batch_sampler.sampler (#16841)
1 parent beced48 commit 235e692

File tree

11 files changed (+90, -87 lines)


src/lightning/fabric/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for saving and loading DeepSpeed checkpoints through `Fabric.save/load()` ([#16452](https://github.com/Lightning-AI/lightning/pull/16452))
 
 
+- Added support for automatically calling `set_epoch` on the `dataloader.batch_sampler.sampler` ([#16841](https://github.com/Lightning-AI/lightning/pull/16841))
+
+
 ### Changed
 
 - Checkpoint saving and loading redesign ([#16434](https://github.com/Lightning-AI/lightning/pull/16434))
@@ -59,6 +62,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed issue where the wrapped dataloader `iter()` would be called twice ([#16841](https://github.com/Lightning-AI/lightning/pull/16841))
+
 - Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
 - Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))

src/lightning/fabric/utilities/data.py

Lines changed: 22 additions & 0 deletions
@@ -414,3 +414,25 @@ def _replace_value_in_saved_args(
         return True, args, kwargs
 
     return False, args, kwargs
+
+
+def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
+    """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
+
+    Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
+    :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
+    of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
+    """
+    objects = set()
+    # check dataloader.sampler
+    if (sampler := getattr(dataloader, "sampler", None)) is not None:
+        objects.add(sampler)
+    # check dataloader.batch_sampler.sampler
+    if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
+        sampler := getattr(batch_sampler, "sampler", None)
+    ) is not None:
+        objects.add(sampler)
+    for obj in objects:
+        set_epoch = getattr(obj, "set_epoch", None)
+        if callable(set_epoch):
+            set_epoch(epoch)
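
For orientation, here is a minimal sketch (not part of the commit) of the case this helper now covers: a `DistributedSampler` that is only reachable through `dataloader.batch_sampler.sampler`. The dataset, batch size, and explicit `num_replicas`/`rank` values are illustrative choices so the snippet runs without initializing a process group.

from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from lightning.fabric.utilities.data import _set_sampler_epoch  # added in this commit

dataset = list(range(8))
# Passing num_replicas/rank explicitly avoids the need for a process group in this sketch.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
dataloader = DataLoader(dataset, batch_sampler=BatchSampler(sampler, batch_size=2, drop_last=False))

for epoch in range(2):
    # `dataloader.sampler` is a plain SequentialSampler here (no `set_epoch`), so the helper
    # reaches the DistributedSampler via `dataloader.batch_sampler.sampler` instead.
    _set_sampler_epoch(dataloader, epoch)
    assert sampler.epoch == epoch
    _ = list(dataloader)  # each epoch now yields a differently shuffled order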

src/lightning/fabric/wrappers.py

Lines changed: 9 additions & 11 deletions
@@ -26,6 +26,7 @@
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
 from lightning.fabric.strategies import Strategy
 from lightning.fabric.utilities import move_data_to_device
+from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.types import Optimizable
 
@@ -168,20 +169,17 @@ def __len__(self) -> int:
         return len(self._dataloader)
 
     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
-        if hasattr(self._dataloader.sampler, "set_epoch"):
-            # Without setting the epoch, the distributed sampler would return the same indices every time, even when
-            # shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
-            # In Lite, we take care of this boilerplate code.
-            self._dataloader.sampler.set_epoch(self._num_iter_calls)
+        # Without setting the epoch, the distributed sampler would return the same indices every time, even when
+        # shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
+        # In Fabric, we take care of this boilerplate code.
+        _set_sampler_epoch(self._dataloader, self._num_iter_calls)
         self._num_iter_calls += 1
 
-        iterator = iter(self._dataloader)
         if self._device is None:
-            yield from iterator
-            return
-
-        for item in iterator:
-            yield move_data_to_device(item, self._device)
+            yield from iter(self._dataloader)
+        else:
+            for item in self._dataloader:
+                yield move_data_to_device(item, self._device)
 
 
 def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
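
From the user's side, the wrapper change can be sketched as follows. This is an illustrative single-process CPU run, not code from the commit, and it relies only on the public `Fabric.setup_dataloaders` API: each `iter()` over the Fabric-wrapped dataloader forwards an increasing epoch to the sampler before the first batch is fetched, even when the sampler sits under a `BatchSampler`, and the underlying dataloader is iterated only once per epoch.

import torch
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

dataset = torch.arange(8)
# Explicit num_replicas/rank keep the sketch runnable without a process group.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
dataloader = DataLoader(dataset, batch_sampler=BatchSampler(sampler, batch_size=2, drop_last=False))

# With a single-device strategy, Fabric keeps the sampler above as-is and only wraps the loader.
dataloader = fabric.setup_dataloaders(dataloader)

for epoch in range(3):
    # No manual `sampler.set_epoch(epoch)` is needed anymore: the wrapper's `__iter__`
    # calls `_set_sampler_epoch`, which also covers `batch_sampler.sampler`.
    for batch in dataloader:
        pass
    assert sampler.epoch == epoch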

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 2 additions & 1 deletion
@@ -21,11 +21,12 @@
 from torch import Tensor
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.loop import _Loop
 from lightning.pytorch.loops.progress import BatchProgress
-from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch
+from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 2 additions & 2 deletions
@@ -15,12 +15,12 @@
 from typing import Optional, Union
 
 import lightning.pytorch as pl
-from lightning.fabric.utilities.data import _auto_add_worker_init_fn
+from lightning.fabric.utilities.data import _auto_add_worker_init_fn, _set_sampler_epoch
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.fetchers import _DataFetcher
 from lightning.pytorch.loops.progress import Progress
 from lightning.pytorch.loops.training_epoch_loop import _TrainingEpochLoop
-from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher, _set_sampler_epoch
+from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 2 additions & 1 deletion
@@ -6,11 +6,12 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities import move_data_to_device
+from lightning.fabric.utilities.data import _set_sampler_epoch
 from lightning.pytorch.callbacks import BasePredictionWriter
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.loops.loop import _Loop
 from lightning.pytorch.loops.progress import Progress
-from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch
+from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
 from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer import call

src/lightning/pytorch/loops/utilities.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from contextlib import contextmanager
15-
from typing import Any, Callable, Generator, Iterable, Optional, Tuple
15+
from typing import Any, Callable, Generator, Optional, Tuple
1616

1717
import torch
1818
import torch.distributed as dist
@@ -123,28 +123,6 @@ def _reset_progress(loop: _Loop) -> None:
123123
_reset_progress(v)
124124

125125

126-
def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
127-
"""Calls the ``set_epoch`` method on either the sampler of the given dataloader.
128-
129-
Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
130-
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
131-
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
132-
"""
133-
objects = set()
134-
# check dataloader.sampler
135-
if (sampler := getattr(dataloader, "sampler", None)) is not None:
136-
objects.add(sampler)
137-
# check dataloader.batch_sampler.sampler
138-
if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
139-
sampler := getattr(batch_sampler, "sampler", None)
140-
) is not None:
141-
objects.add(sampler)
142-
for obj in objects:
143-
set_epoch = getattr(obj, "set_epoch", None)
144-
if callable(set_epoch):
145-
set_epoch(epoch)
146-
147-
148126
def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher:
149127
lightning_module = trainer.lightning_module
150128
if trainer.testing:

tests/tests_fabric/test_wrappers.py

Lines changed: 21 additions & 9 deletions
@@ -17,7 +17,7 @@
 import pytest
 import torch
 from tests_fabric.helpers.runif import RunIf
-from torch.utils.data import DistributedSampler
+from torch.utils.data import BatchSampler, DistributedSampler
 from torch.utils.data.dataloader import DataLoader
 
 from lightning.fabric.fabric import Fabric
@@ -232,24 +232,36 @@ def test_fabric_dataloader_device_placement(src_device_str, dest_device_str):
     assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))
 
 
-def test_fabric_dataloader_distributed_sampler_set_epoch():
+@pytest.mark.parametrize("use_batch_sampler", (False, True))
+def test_fabric_dataloader_distributed_sampler_set_epoch(use_batch_sampler):
     """Test that the FabricDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
-    sampler = DistributedSampler(range(3), num_replicas=2, rank=0)
+    dataset = range(3)
+    sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
     sampler.set_epoch = Mock()
-    dataloader = DataLoader(range(3), sampler=sampler)
+
+    if not use_batch_sampler:
+        dataloader = DataLoader(dataset, sampler=sampler)
+    else:
+        batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=False)
+        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+
     fabric_dataloader = _FabricDataLoader(dataloader)
     iterator_epoch_0 = iter(fabric_dataloader)
-    dataloader.sampler.set_epoch.assert_not_called()
+    sampler.set_epoch.assert_not_called()
+
     next(iterator_epoch_0)
     # .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    assert sampler.set_epoch.mock_calls == [call(0)]
+
     next(iterator_epoch_0)
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    assert sampler.set_epoch.mock_calls == [call(0)]
+
     iterator_epoch_1 = iter(fabric_dataloader)
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    assert sampler.set_epoch.mock_calls == [call(0)]
+
     next(iterator_epoch_1)
     # with every new iterator call, the epoch increases
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
+    assert sampler.set_epoch.mock_calls == [call(0), call(1)]
 
 
 def test_fabric_optimizer_wraps():

tests/tests_fabric/utilities/test_data.py

Lines changed: 22 additions & 0 deletions
@@ -1,4 +1,5 @@
 import random
+from unittest.mock import Mock
 
 import numpy as np
 import pytest
@@ -12,6 +13,7 @@
     _get_dataloader_init_args_and_kwargs,
     _replace_dunder_methods,
     _replace_value_in_saved_args,
+    _set_sampler_epoch,
     _update_dataloader,
     _WrapAttrTag,
     has_iterable_dataset,
@@ -525,3 +527,23 @@ def __init__(self, indices=None, **kwargs):
     dataloader = ArrayAttributeDataloader(dataset)
     dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
     assert dl_kwargs["indices"] is dataloader.indices
+
+
+def test_set_sampler_epoch():
+    # No samplers
+    dataloader = Mock()
+    dataloader.sampler = None
+    dataloader.batch_sampler = None
+    _set_sampler_epoch(dataloader, 55)
+
+    # set_epoch not callable
+    dataloader = Mock()
+    dataloader.sampler.set_epoch = None
+    dataloader.batch_sampler.set_epoch = None
+    _set_sampler_epoch(dataloader, 55)
+
+    # set_epoch callable
+    dataloader = Mock()
+    _set_sampler_epoch(dataloader, 55)
+    dataloader.sampler.set_epoch.assert_called_once_with(55)
+    dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)
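
The test above drives the helper with `Mock` objects; the same contract can be sketched with real PyTorch samplers (again illustrative, not part of the commit):

from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler

from lightning.fabric.utilities.data import _set_sampler_epoch

dataset = list(range(4))

# A sampler without `set_epoch` (e.g. SequentialSampler): the helper is a silent no-op.
plain = DataLoader(dataset, sampler=SequentialSampler(dataset))
_set_sampler_epoch(plain, 55)

# A DistributedSampler exposes `set_epoch`, so the helper forwards the epoch to it.
dist_sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
_set_sampler_epoch(DataLoader(dataset, sampler=dist_sampler), 55)
assert dist_sampler.epoch == 55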

tests/tests_pytorch/loops/test_prediction_loop.py

Lines changed: 4 additions & 4 deletions
@@ -51,8 +51,8 @@ def predict_step(self, batch, batch_idx):
     assert trainer.predict_loop.predictions == []
 
 
-@pytest.mark.parametrize("replace_sampler_ddp", (False, True))
-def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, replace_sampler_ddp):
+@pytest.mark.parametrize("use_distributed_sampler", (False, True))
+def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
     """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""
     trainer = Trainer(
         default_root_dir=tmp_path,
@@ -63,14 +63,14 @@ def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, replace_sample
         strategy="ddp",
         devices=1,
         accelerator="cpu",
-        replace_sampler_ddp=replace_sampler_ddp,
+        use_distributed_sampler=use_distributed_sampler,
     )
 
     class MyModel(BoringModel):
         def predict_dataloader(self):
             dataset = RandomDataset(32, 64)
             sampler = None
-            if not replace_sampler_ddp:
+            if not use_distributed_sampler:
                 sampler = DistributedSampler(dataset)
             return DataLoader(dataset, sampler=sampler)

tests/tests_pytorch/loops/test_utilities.py

Lines changed: 0 additions & 36 deletions
This file was deleted.
