Skip to content

Commit 123984d

Browse files
committed
Update tests to reflect changed reload_dataloaders_every_n_epochs behavior
1 parent 4c81836 commit 123984d

File tree

2 files changed

+27
-61
lines changed

2 files changed

+27
-61
lines changed

tests/models/test_hooks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,6 @@ def call(hook, fn, *args, **kwargs):
872872
dict(name="setup", kwargs=dict(stage="fit")),
873873
dict(name="val_dataloader"),
874874
dict(name="train_dataloader"),
875-
dict(name="val_dataloader"),
876875
dict(name="on_save_checkpoint", args=(ANY,)),
877876
dict(name="teardown", kwargs=dict(stage="fit")),
878877
]

tests/trainer/test_dataloaders.py

Lines changed: 27 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,52 +1118,6 @@ def test_dataloaders_load_only_once(tmpdir):
11181118

11191119
assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()]
11201120

1121-
1122-
def test_dataloaders_load_only_once_val_interval(tmpdir):
1123-
model = BoringModel()
1124-
1125-
# logger file to get meta
1126-
trainer = Trainer(
1127-
default_root_dir=tmpdir,
1128-
limit_train_batches=10,
1129-
limit_val_batches=10,
1130-
val_check_interval=0.3,
1131-
reload_dataloaders_every_n_epochs=1,
1132-
max_epochs=3,
1133-
)
1134-
1135-
tracker = Mock()
1136-
model.train_dataloader = Mock(wraps=model.train_dataloader)
1137-
model.val_dataloader = Mock(wraps=model.val_dataloader)
1138-
model.test_dataloader = Mock(wraps=model.test_dataloader)
1139-
1140-
tracker.attach_mock(model.train_dataloader, "train_dataloader")
1141-
tracker.attach_mock(model.val_dataloader, "val_dataloader")
1142-
tracker.attach_mock(model.test_dataloader, "test_dataloader")
1143-
1144-
trainer.fit(model)
1145-
trainer.test(model)
1146-
1147-
# verify the sequence
1148-
expected_sequence = [
1149-
call.val_dataloader(),
1150-
call.train_dataloader(),
1151-
call.val_dataloader(),
1152-
call.val_dataloader(),
1153-
call.val_dataloader(),
1154-
call.train_dataloader(),
1155-
call.val_dataloader(),
1156-
call.val_dataloader(),
1157-
call.val_dataloader(),
1158-
call.train_dataloader(),
1159-
call.val_dataloader(),
1160-
call.val_dataloader(),
1161-
call.val_dataloader(),
1162-
call.test_dataloader(),
1163-
]
1164-
assert tracker.mock_calls == expected_sequence
1165-
1166-
11671121
def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
11681122
model = BoringModel()
11691123

@@ -1190,14 +1144,25 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
11901144

11911145
@pytest.mark.parametrize("n", [1, 2])
11921146
def test_dataloaders_load_every_n_epochs(tmpdir, n):
1193-
model = BoringModel()
1147+
train_reload_epochs, val_reload_epochs = [], []
1148+
1149+
class TestModel(BoringModel):
1150+
def train_dataloader(self):
1151+
train_reload_epochs.append(self.current_epoch)
1152+
return super().train_dataloader()
1153+
1154+
def val_dataloader(self):
1155+
val_reload_epochs.append(self.current_epoch)
1156+
return super().val_dataloader()
1157+
1158+
model = TestModel()
11941159

11951160
trainer = Trainer(
11961161
default_root_dir=tmpdir,
11971162
limit_train_batches=0.3,
11981163
limit_val_batches=0.3,
11991164
reload_dataloaders_every_n_epochs=n,
1200-
max_epochs=3,
1165+
max_epochs=5,
12011166
)
12021167

12031168
tracker = Mock()
@@ -1212,15 +1177,24 @@ def test_dataloaders_load_every_n_epochs(tmpdir, n):
12121177
trainer.fit(model)
12131178
trainer.test(model)
12141179

1215-
# verify the sequence
1216-
expected_sequence = [call.val_dataloader()]
1180+
# Verify the sequence
1181+
expected_sequence = [call.val_dataloader(), call.train_dataloader()] # Sanity check first
12171182
if n == 1:
1218-
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 3
1183+
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 4
12191184
elif n == 2:
12201185
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
12211186
expected_sequence += [call.test_dataloader()]
1187+
12221188
assert tracker.mock_calls == expected_sequence
12231189

1190+
# Verify epoch of reloads
1191+
if n == 1:
1192+
assert train_reload_epochs == [0, 1, 2, 3, 4]
1193+
assert val_reload_epochs == [0, 1, 2, 3, 4]
1194+
elif n == 2:
1195+
assert train_reload_epochs == [0, 2, 4]
1196+
assert val_reload_epochs == [0, 2, 4]
1197+
12241198

12251199
@pytest.mark.parametrize("n", [3, 5])
12261200
def test_dataloaders_load_every_n_epochs_infrequent_val(tmpdir, n):
@@ -1263,6 +1237,7 @@ def val_dataloader(self):
12631237

12641238
model.test_dataloader.assert_called_once()
12651239

1240+
12661241
def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir):
12671242
train_reload_epochs, val_reload_epochs, val_check_epochs = [], [], []
12681243

@@ -1303,6 +1278,7 @@ def validation_epoch_end(self, outputs):
13031278
# Verify validation happens 3 times per epoch + 1 for sanity check
13041279
assert val_check_epochs == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
13051280

1281+
13061282
@pytest.mark.parametrize("n", ["test", -1])
13071283
def test_dataloaders_load_every_n_epochs_exception(tmpdir, n):
13081284

@@ -1347,15 +1323,6 @@ def validation_step(self, batch, batch_idx):
13471323
expected_calls = [
13481324
call.train_dataloader(),
13491325
call.val_dataloader(),
1350-
# This has subsequent calls to val_dataloader
1351-
# because the training loop runs the evaluation loop,
1352-
# which reloads the val dataloader again.
1353-
# We cannot yet rely on trainer.current_epoch=0 to skip reloading
1354-
# the val dataloader on the first epoch because this only tracks the training epoch
1355-
# meaning multiple passes through the validation data within a single training epoch
1356-
# would not have the dataloader reloaded.
1357-
# This breaks the assumption behind reload_dataloaders_every_n_epochs=1
1358-
call.val_dataloader(),
13591326
call.train_dataloader(),
13601327
call.val_dataloader(),
13611328
call.train_dataloader(),

0 commit comments

Comments
 (0)