
Commit 4398db2

awaelchli, adamviola, pre-commit-ci[bot], Borda, and tchaton authored and committed
Fix _should_reload_dl_epoch causing inconsistent validation dataloader reloading (#11036)
Co-authored-by: Adam Viola <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent a85033a · commit 4398db2

File tree: 8 files changed, +143 -70 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
 - Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
+- Fixed dataloaders not getting reloaded the correct amount of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))

 ## [1.5.7] - 2021-12-21
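For orientation, a minimal sketch (assuming PyTorch Lightning 1.5.x, with arbitrarily chosen values, not taken from the diff) of the two Trainer flags whose interaction this changelog entry describes:

```python
# Minimal sketch, not part of the commit: the flags whose interaction is being fixed.
from pytorch_lightning import Trainer

trainer = Trainer(
    reload_dataloaders_every_n_epochs=2,  # rebuild the dataloaders every 2 epochs
    check_val_every_n_epoch=3,            # run validation (and hence reload its dataloader) every 3 epochs
    max_epochs=10,
)
```

With the old modulo-based check removed in this commit, reloads could fall on epochs where validation did not run; the new per-dataloader tracking introduced below ties each reload to the epoch of the previous reset.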

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ def _reload_evaluation_dataloaders(self) -> None:
         """Reloads dataloaders if necessary."""
         if self.trainer.testing:
             self.trainer.reset_test_dataloader()
-        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
+        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl:
             self.trainer.reset_val_dataloader()

     def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:

pytorch_lightning/loops/fit_loop.py

Lines changed: 1 addition & 1 deletion

@@ -205,7 +205,7 @@ def on_advance_start(self) -> None:
         model = self.trainer.lightning_module

         # reset train dataloader
-        if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
+        if not self._is_fresh_start_epoch and self.trainer._should_reload_train_dl:
             self.trainer.reset_train_dataloader(model)
         self._is_fresh_start_epoch = False

pytorch_lightning/trainer/data_loading.py

Lines changed: 22 additions & 0 deletions

@@ -50,6 +50,7 @@ class TrainerDataLoadingMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     val_check_interval: float
+    reload_dataloaders_every_n_epochs: int
     tpu_local_core_rank: int
     train_dataloader: DataLoader
     limit_train_batches: Union[int, float]

@@ -70,6 +71,21 @@ class TrainerDataLoadingMixin(ABC):
     accelerator: Accelerator
     accelerator_connector: AcceleratorConnector
     call_hook: Callable
+    current_epoch: int
+    _last_train_dl_reload_epoch: int
+    _last_val_dl_reload_epoch: int
+
+    @property
+    def _should_reload_train_dl(self) -> bool:
+        """Check if train dataloader should be reloaded."""
+        n_epochs = self.reload_dataloaders_every_n_epochs
+        return n_epochs and (self.current_epoch - self._last_train_dl_reload_epoch >= n_epochs)
+
+    @property
+    def _should_reload_val_dl(self) -> bool:
+        """Check if validation dataloader should be reloaded."""
+        n_epochs = self.reload_dataloaders_every_n_epochs
+        return n_epochs and (self.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)

     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         if not isinstance(dataloader, DataLoader):

@@ -415,6 +431,9 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
                 " you want to see logs for the training epoch."
             )

+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_train_dl_reload_epoch = self.current_epoch
+
     def _reset_eval_dataloader(
         self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
     ) -> Tuple[List[Union[int, float]], List[DataLoader]]:

@@ -529,6 +548,9 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
             RunningStage.VALIDATING, model=pl_module
         )

+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_val_dl_reload_epoch = self.current_epoch
+
     def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
         """Resets the test dataloader and determines the number of batches.
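To make the behavioral change concrete, here is a small standalone sketch (not code from the diff) contrasting the removed modulo-based predicate with the new per-dataloader predicates defined above:

```python
# Standalone sketch, not part of the commit: old vs. new reload decision.
def old_should_reload(current_epoch: int, n_epochs: int) -> bool:
    # behaviour of the removed Trainer._should_reload_dl_epoch property
    return bool(n_epochs) and current_epoch % n_epochs == 0


def new_should_reload(current_epoch: int, last_reload_epoch: float, n_epochs: int) -> bool:
    # behaviour of the new _should_reload_train_dl / _should_reload_val_dl properties
    return bool(n_epochs) and current_epoch - last_reload_epoch >= n_epochs


# With n_epochs=2 and the val dataloader last reset at epoch 1 (e.g. because
# validation ran then), the old check reloads again at epoch 2, while the new
# check waits until at least epoch 3.
print(old_should_reload(2, 2))       # True
print(new_should_reload(2, 1, 2))    # False
print(new_should_reload(3, 1, 2))    # True
```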

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 6 deletions

@@ -663,6 +663,8 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None:
         self.num_val_batches = []
         self.test_dataloaders = None
         self.val_dataloaders = None
+        self._last_train_dl_reload_epoch = float("-inf")
+        self._last_val_dl_reload_epoch = float("-inf")

         # when true, print evaluation results in .validate() and .test()
         self.verbose_evaluate = True

@@ -752,6 +754,8 @@ def _fit_impl(
         self.state.fn = TrainerFn.FITTING
         self.state.status = TrainerStatus.RUNNING
         self.training = True
+        self._last_train_dl_reload_epoch = float("-inf")
+        self._last_val_dl_reload_epoch = float("-inf")

         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloaders, LightningDataModule):

@@ -1826,12 +1830,6 @@ def progress_bar_dict(self) -> dict:
             return self.progress_bar_callback.get_metrics(self, ref_model)
         return self.progress_bar_metrics

-    @property
-    def _should_reload_dl_epoch(self) -> bool:
-        """Check if dataloader should be reloaded in the current epoch."""
-        n_epochs = self.reload_dataloaders_every_n_epochs
-        return n_epochs and (not self.current_epoch % n_epochs)
-
     @property
     def disable_validation(self) -> bool:
         """Check if validation is disabled during training."""

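One detail worth noting about the `float("-inf")` sentinel initialised above (a worked example of the arithmetic, not code from the diff): it makes the very first reload check succeed, so the dataloaders are always built at the start of fitting.

```python
# Illustrative arithmetic only: the sentinel guarantees the first reset happens.
last_reload_epoch = float("-inf")
current_epoch, n_epochs = 0, 2
print(current_epoch - last_reload_epoch >= n_epochs)  # True: the distance is infinite
```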
tests/deprecated_api/test_remove_1-6.py

Lines changed: 2 additions & 3 deletions

@@ -118,13 +118,12 @@ def test_v1_6_0_reload_dataloaders_every_epoch(tmpdir):
         limit_val_batches=0.3,
         reload_dataloaders_every_epoch=True,
         max_epochs=3,
+        num_sanity_val_steps=0,
     )
     trainer.fit(model)
     trainer.test()

-    expected_sequence = (
-        [call.val_dataloader()] + [call.train_dataloader(), call.val_dataloader()] * 3 + [call.test_dataloader()]
-    )
+    expected_sequence = [call.train_dataloader(), call.val_dataloader()] * 3 + [call.test_dataloader()]
     assert tracker.mock_calls == expected_sequence
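As a quick illustrative check of the updated expectation (not part of the test file): with `num_sanity_val_steps=0` there is no leading `val_dataloader` call, leaving one train/val reload pair per epoch over three epochs plus the final test loader.

```python
# Illustrative count of the updated expected_sequence: 3 * (train + val) + test = 7 calls.
from unittest.mock import call

expected_sequence = [call.train_dataloader(), call.val_dataloader()] * 3 + [call.test_dataloader()]
assert len(expected_sequence) == 7
```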

tests/models/test_hooks.py

Lines changed: 0 additions & 1 deletion

@@ -878,7 +878,6 @@ def call(hook, fn, *args, **kwargs):
         *batch_transfer * batches,
         dict(name="train_dataloader"),
         *batch_transfer * batches,
-        dict(name="val_dataloader"),
         *batch_transfer * batches,
         dict(
             name="on_save_checkpoint",

tests/trainer/test_dataloaders.py

Lines changed: 112 additions & 58 deletions

@@ -1122,17 +1122,12 @@ def test_dataloaders_load_only_once(tmpdir):
     assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()]


-def test_dataloaders_load_only_once_val_interval(tmpdir):
+def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
     model = BoringModel()

     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=10,
-        limit_val_batches=10,
-        val_check_interval=0.3,
-        reload_dataloaders_every_n_epochs=True,
-        max_epochs=3,
+        default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
     )

     tracker = Mock()

@@ -1145,34 +1140,33 @@ def test_dataloaders_load_only_once_val_interval(tmpdir):
     tracker.attach_mock(model.test_dataloader, "test_dataloader")

     trainer.fit(model)
-    trainer.test(model)

     # verify the sequence
-    expected_sequence = [
-        call.val_dataloader(),
-        call.train_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.train_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.train_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.test_dataloader(),
-    ]
+    expected_sequence = [call.train_dataloader(), call.val_dataloader()]
     assert tracker.mock_calls == expected_sequence


-def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
-    model = BoringModel()
+@pytest.mark.parametrize("n", [1, 2])
+def test_dataloaders_load_every_n_epochs(tmpdir, n):
+    train_reload_epochs, val_reload_epochs = [], []
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            train_reload_epochs.append(self.current_epoch)
+            return super().train_dataloader()
+
+        def val_dataloader(self):
+            val_reload_epochs.append(self.current_epoch)
+            return super().val_dataloader()
+
+    model = TestModel()

-    # logger file to get meta
     trainer = Trainer(
-        default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
+        default_root_dir=tmpdir,
+        limit_train_batches=0.3,
+        limit_val_batches=0.3,
+        reload_dataloaders_every_n_epochs=n,
+        max_epochs=5,
     )

     tracker = Mock()

@@ -1185,44 +1179,113 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
     tracker.attach_mock(model.test_dataloader, "test_dataloader")

     trainer.fit(model)
+    trainer.test(model)
+
+    # Verify the sequence
+    expected_sequence = [call.val_dataloader(), call.train_dataloader()]  # Sanity check first
+    if n == 1:
+        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 4
+    elif n == 2:
+        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
+    expected_sequence += [call.test_dataloader()]

-    # verify the sequence
-    expected_sequence = [call.train_dataloader(), call.val_dataloader()]
     assert tracker.mock_calls == expected_sequence

+    # Verify epoch of reloads
+    if n == 1:
+        assert train_reload_epochs == [0, 1, 2, 3, 4]
+        assert val_reload_epochs == [0, 1, 2, 3, 4]
+    elif n == 2:
+        assert train_reload_epochs == [0, 2, 4]
+        assert val_reload_epochs == [0, 2, 4]

-@pytest.mark.parametrize("n", [1, 2])
-def test_dataloaders_load_every_n_epochs(tmpdir, n):
-    model = BoringModel()
+
+@pytest.mark.parametrize(
+    "n, train_reload_epochs_expect, val_reload_epochs_expect",
+    [
+        # Sanity check at epoch 0 creates a validation dataloader, but validation is
+        # checked (and in this case reloaded) every n epochs starting from epoch n-1
+        (3, [0, 2, 4, 6, 8], [0, 2, 5, 8]),
+        (5, [0, 2, 4, 6, 8], [0, 4, 9]),
+    ],
+)
+def test_dataloaders_load_every_n_epochs_infrequent_val(
+    tmpdir, n, train_reload_epochs_expect, val_reload_epochs_expect
+):
+    """Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)"""
+    train_reload_epochs, val_reload_epochs = [], []
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            train_reload_epochs.append(self.current_epoch)
+            return super().train_dataloader()
+
+        def val_dataloader(self):
+            val_reload_epochs.append(self.current_epoch)
+            return super().val_dataloader()
+
+    model = TestModel()

     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=0.3,
         limit_val_batches=0.3,
-        reload_dataloaders_every_n_epochs=n,
+        check_val_every_n_epoch=n,
+        reload_dataloaders_every_n_epochs=2,
+        max_epochs=10,
+    )
+    model.test_dataloader = Mock(wraps=model.test_dataloader)
+
+    trainer.fit(model)
+    trainer.test(model)
+
+    # Verify epoch of reloads
+    assert train_reload_epochs == train_reload_epochs_expect
+    assert val_reload_epochs == val_reload_epochs_expect
+
+    model.test_dataloader.assert_called_once()
+
+
+def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir):
+    """Test dataloader reload behavior when frequently checking validation set (via val_check_interval)"""
+    train_reload_epochs, val_reload_epochs, val_check_epochs = [], [], []
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            train_reload_epochs.append(self.current_epoch)
+            return super().train_dataloader()
+
+        def val_dataloader(self):
+            val_reload_epochs.append(self.current_epoch)
+            return super().val_dataloader()
+
+        def validation_epoch_end(self, outputs):
+            val_check_epochs.append(self.current_epoch)
+            return super().validation_epoch_end(outputs)
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=0.3,
+        limit_val_batches=0.3,
+        val_check_interval=0.3,
+        reload_dataloaders_every_n_epochs=1,
         max_epochs=3,
     )

-    tracker = Mock()
-    model.train_dataloader = Mock(wraps=model.train_dataloader)
-    model.val_dataloader = Mock(wraps=model.val_dataloader)
     model.test_dataloader = Mock(wraps=model.test_dataloader)

-    tracker.attach_mock(model.train_dataloader, "train_dataloader")
-    tracker.attach_mock(model.val_dataloader, "val_dataloader")
-    tracker.attach_mock(model.test_dataloader, "test_dataloader")
-
     trainer.fit(model)
     trainer.test(model)

-    # verify the sequence
-    expected_sequence = [call.val_dataloader()]
-    if n == 1:
-        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 3
-    elif n == 2:
-        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
-    expected_sequence += [call.test_dataloader()]
-    assert tracker.mock_calls == expected_sequence
+    # Verify epoch of reloads
+    assert train_reload_epochs == [0, 1, 2]
+    assert val_reload_epochs == [0, 1, 2]
+    model.test_dataloader.assert_called_once()
+
+    # Verify validation happens 3 times per epoch + 1 for sanity check
+    assert val_check_epochs == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]


 @pytest.mark.parametrize("n", ["test", -1])

@@ -1269,15 +1332,6 @@ def validation_step(self, batch, batch_idx):
     expected_calls = [
         call.train_dataloader(),
         call.val_dataloader(),
-        # This has subsequent calls to val_dataloader
-        # because the training loop runs the evaluation loop,
-        # which reloads the val dataloader again.
-        # We cannot yet rely on trainer.current_epoch=0 to skip reloading
-        # the val dataloader on the first epoch because this only tracks the training epoch
-        # meaning multiple passes through the validation data within a single training epoch
-        # would not have the dataloader reloaded.
-        # This breaks the assumption behind reload_dataloaders_every_epoch=True
-        call.val_dataloader(),
         call.train_dataloader(),
         call.val_dataloader(),
         call.train_dataloader(),
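As a worked example of the parametrized expectations above (a hand simulation under the assumptions stated in the test comments, not code from the diff), the validation reload epochs for `check_val_every_n_epoch=3` with `reload_dataloaders_every_n_epochs=2` can be derived as follows:

```python
# Hand simulation, not part of the test suite: derive the expected val reload epochs for n=3.
reload_every_n, check_val_every_n, max_epochs = 2, 3, 10

val_reloads, last_reload = [0], 0  # epoch 0: the sanity check builds the val dataloader once
for epoch in range(max_epochs):
    runs_validation = (epoch + 1) % check_val_every_n == 0  # validation runs at epochs 2, 5, 8
    if runs_validation and epoch - last_reload >= reload_every_n:
        val_reloads.append(epoch)
        last_reload = epoch

print(val_reloads)  # [0, 2, 5, 8] -- matches val_reload_epochs_expect for n=3 in the test above
```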
