Skip to content

Commit b530b7a — "update tests to not rely on patched dataloaders (#9905)"
Parent commit: 98c0a11

File tree: 6 files changed (+42 lines added, −25 lines removed)

tests/callbacks/test_early_stopping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
9595
)
9696

9797
with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
98-
new_trainer.fit(model)
98+
new_trainer.fit(model, datamodule=dm)
9999

100100

101101
def test_early_stopping_no_extraneous_invocations(tmpdir):

tests/models/test_restore.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
340340
new_trainer.test(pretrained_model)
341341
pretrained_model.cpu()
342342

343-
dataloaders = model.test_dataloader()
343+
dataloaders = dm.test_dataloader()
344344
if not isinstance(dataloaders, list):
345345
dataloaders = [dataloaders]
346346

@@ -539,7 +539,7 @@ def on_pretrain_routine_end(self):
539539
# haven't trained with the new loaded model
540540
new_trainer.state.stage = RunningStage.VALIDATING
541541

542-
dataloader = self.train_dataloader()
542+
dataloader = dm.train_dataloader()
543543
tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
544544
self.on_pretrain_routine_end_called = True
545545

tests/trainer/test_data_loading.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,19 +267,19 @@ def test_loader_detaching():
267267

268268
class LoaderTestModel(BoringModel):
269269
def training_step(self, batch, batch_idx):
270-
assert len(model.train_dataloader()) == 10
270+
assert len(self.trainer.train_dataloader.loaders) == 10
271271
return super().training_step(batch, batch_idx)
272272

273273
def validation_step(self, batch, batch_idx):
274-
assert len(model.val_dataloader()) == 10
274+
assert len(self.trainer.val_dataloaders[0]) == 10
275275
return super().validation_step(batch, batch_idx)
276276

277277
def test_step(self, batch, batch_idx):
278-
assert len(model.test_dataloader()) == 10
278+
assert len(self.trainer.test_dataloaders[0]) == 10
279279
return super().test_step(batch, batch_idx)
280280

281281
def predict_step(self, batch, batch_idx, dataloader_idx=None):
282-
assert len(model.predict_dataloader()) == 10
282+
assert len(self.trainer.predict_dataloaders[0]) == 10
283283
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
284284

285285
loader = DataLoader(RandomDataset(32, 10), batch_size=1)

tests/trainer/test_dataloaders.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
184184
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
185185
model.test_step = model.test_step__multiple_dataloaders
186186

187-
# train, multiple val and multiple test passed to fit
187+
# multiple val dataloaders passed to fit
188188
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
189189
trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders)
190190

@@ -195,10 +195,10 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
195195
ckpt_path = trainer.checkpoint_callback.best_model_path
196196

197197
trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
198-
trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
198+
assert len(trainer.test_dataloaders) == n
199199

200+
trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
200201
assert len(trainer.val_dataloaders) == n
201-
assert len(trainer.test_dataloaders) == n
202202

203203

204204
class DummyModel(BoringModel):
@@ -551,17 +551,15 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
551551
# fit model
552552
trainer = Trainer(**trainer_options)
553553
trainer.fit(model, val_dataloaders=model.dataloader(train=False))
554-
assert trainer.state.finished, f"Training failed with {trainer.state}"
555554

556555
# fit model
557556
trainer = Trainer(**trainer_options)
558557
trainer.fit(model, val_dataloaders=model.dataloader(train=False))
559-
assert trainer.state.finished, f"Training failed with {trainer.state}"
558+
assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
559+
560560
if ckpt_path == "specific":
561561
ckpt_path = trainer.checkpoint_callback.best_model_path
562562
trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)
563-
564-
assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
565563
assert (
566564
len(trainer.test_dataloaders) == 1
567565
), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}"
@@ -1313,8 +1311,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):
13131311

13141312

13151313
def test_dataloaders_reset_and_attach(tmpdir):
1316-
"""Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before
1317-
attaching the new one."""
1314+
"""Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching
1315+
the new one."""
13181316
# the assertions compare the datasets and not dataloaders since we patch and replace the samplers
13191317
dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
13201318
dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))

tests/trainer/test_trainer_tricks.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir):
8484
# test train loader applies correct limits
8585
# ------------------------------------------------------
8686
trainer = Trainer(overfit_batches=4)
87+
trainer.data_connector.attach_dataloaders(model=model)
8788
trainer.reset_train_dataloader(model)
8889
assert trainer.num_training_batches == 4
8990

@@ -93,6 +94,7 @@ def test_overfit_batch_limits(tmpdir):
9394
assert torch.eq(ya, yb).all()
9495

9596
trainer = Trainer(overfit_batches=0.11)
97+
trainer.data_connector.attach_dataloaders(model=model)
9698
trainer.reset_train_dataloader(model)
9799
# The dataloader should have been overwritten with a Sequential sampler.
98100
assert trainer.train_dataloader is not train_loader
@@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir):
111113
# ------------------------------------------------------
112114
# test overfit_batches as percent
113115
# ------------------------------------------------------
114-
loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model)
116+
trainer = Trainer(overfit_batches=0.11)
117+
trainer.data_connector.attach_dataloaders(model)
118+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
115119
assert loader_num_batches[0] == num_train_samples
116120

117121
# make sure we turned off shuffle for the user
@@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir):
125129
# ------------------------------------------------------
126130
# test overfit_batches as int
127131
# ------------------------------------------------------
128-
loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model)
132+
trainer = Trainer(overfit_batches=1)
133+
trainer.data_connector.attach_dataloaders(model)
134+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
129135
assert loader_num_batches[0] == 1
130-
loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model)
136+
trainer = Trainer(overfit_batches=5)
137+
trainer.data_connector.attach_dataloaders(model)
138+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
131139
assert loader_num_batches[0] == 5
132140

133141
# ------------------------------------------------------
134142
# test limit_xxx_batches as percent AND int
135143
# ------------------------------------------------------
136144
if split == RunningStage.VALIDATING:
137-
loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model)
145+
trainer = Trainer(limit_val_batches=0.1)
146+
trainer.data_connector.attach_dataloaders(model)
147+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
138148
assert loader_num_batches[0] == int(0.1 * len(val_loader))
139149

140-
loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model)
150+
trainer = Trainer(limit_val_batches=10)
151+
trainer.data_connector.attach_dataloaders(model)
152+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
141153
assert loader_num_batches[0] == 10
142154
else:
143-
loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model)
155+
trainer = Trainer(limit_test_batches=0.1)
156+
trainer.data_connector.attach_dataloaders(model)
157+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
144158
assert loader_num_batches[0] == int(0.1 * len(test_loader))
145159

146-
loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model)
160+
trainer = Trainer(limit_test_batches=10)
161+
trainer.data_connector.attach_dataloaders(model)
162+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
147163
assert loader_num_batches[0] == 10

tests/tuner/test_scale_batch_size.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,12 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
220220
limit_train_batches=0.2,
221221
auto_scale_batch_size="power",
222222
)
223-
fit_options = dict(train_dataloader=model.dataloader(train=True))
223+
fit_options = dict(train_dataloaders=model.dataloader(train=True))
224224

225-
with pytest.raises(MisconfigurationException):
225+
with pytest.raises(
226+
MisconfigurationException,
227+
match="The batch scaling feature cannot be used with dataloaders passed directly",
228+
):
226229
trainer.tune(model, **fit_options)
227230

228231

Commit comments: 0