Skip to content

Commit 07e5bd3

Browse files
rohitgr7ninginthecloud
authored andcommitted
Raise MisconfigurationException if trainer.eval is missing required methods (Lightning-AI#10016)
1 parent 2051277 commit 07e5bd3

File tree

9 files changed

+117
-101
lines changed

9 files changed

+117
-101
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
340340
- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))
341341

342342

343+
- Raise `MisconfigurationException` instead of warning if `trainer.{validate/test}` is missing required methods ([#10016](https://github.com/PyTorchLightning/pytorch-lightning/pull/10016))
344+
345+
343346
- Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
344347

345348

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 60 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
2929
3030
"""
3131
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
32-
__verify_train_loop_configuration(trainer, model)
33-
__verify_eval_loop_configuration(model, "val")
32+
__verify_train_val_loop_configuration(trainer, model)
3433
__verify_manual_optimization_support(trainer, model)
3534
__check_training_step_requires_dataloader_iter(model)
3635
elif trainer.state.fn == TrainerFn.VALIDATING:
37-
__verify_eval_loop_configuration(model, "val")
36+
__verify_eval_loop_configuration(trainer, model, "val")
3837
elif trainer.state.fn == TrainerFn.TESTING:
39-
__verify_eval_loop_configuration(model, "test")
38+
__verify_eval_loop_configuration(trainer, model, "test")
4039
elif trainer.state.fn == TrainerFn.PREDICTING:
41-
__verify_predict_loop_configuration(trainer, model)
40+
__verify_eval_loop_configuration(trainer, model, "predict")
41+
4242
__verify_dp_batch_transfer_support(trainer, model)
4343
_check_add_get_queue(model)
4444
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
@@ -51,7 +51,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
5151
_check_dl_idx_in_on_train_batch_hooks(trainer, model)
5252

5353

54-
def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
54+
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
5555
# -----------------------------------
5656
# verify model has a training step
5757
# -----------------------------------
@@ -83,24 +83,15 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
8383
)
8484

8585
# ----------------------------------------------
86-
# verify model does not have
87-
# - on_train_dataloader
88-
# - on_val_dataloader
86+
# verify model does not have on_train_dataloader
8987
# ----------------------------------------------
9088
has_on_train_dataloader = is_overridden("on_train_dataloader", model)
9189
if has_on_train_dataloader:
9290
rank_zero_deprecation(
93-
"Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
91+
"Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
9492
" Please use `train_dataloader()` directly."
9593
)
9694

97-
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
98-
if has_on_val_dataloader:
99-
rank_zero_deprecation(
100-
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
101-
" Please use `val_dataloader()` directly."
102-
)
103-
10495
trainer.overriden_optimizer_step = is_overridden("optimizer_step", model)
10596
trainer.overriden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model)
10697
automatic_optimization = model.automatic_optimization
@@ -110,8 +101,30 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
110101
if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization:
111102
rank_zero_warn(
112103
"When using `Trainer(accumulate_grad_batches != 1)` and overriding"
113-
"`LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
114-
"(rather, they are called on every optimization step)."
104+
" `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
105+
" (rather, they are called on every optimization step)."
106+
)
107+
108+
# -----------------------------------
109+
# verify model for val loop
110+
# -----------------------------------
111+
112+
has_val_loader = trainer._data_connector._val_dataloader_source.is_defined()
113+
has_val_step = is_overridden("validation_step", model)
114+
115+
if has_val_loader and not has_val_step:
116+
rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
117+
if has_val_step and not has_val_loader:
118+
rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
119+
120+
# ----------------------------------------------
121+
# verify model does not have on_val_dataloader
122+
# ----------------------------------------------
123+
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
124+
if has_on_val_dataloader:
125+
rank_zero_deprecation(
126+
"Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
127+
" Please use `val_dataloader()` directly."
115128
)
116129

117130

@@ -143,52 +156,43 @@ def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
143156
)
144157

145158

146-
def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> None:
159+
def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
147160
loader_name = f"{stage}_dataloader"
148-
step_name = "validation_step" if stage == "val" else "test_step"
161+
step_name = "validation_step" if stage == "val" else f"{stage}_step"
162+
trainer_method = "validate" if stage == "val" else stage
163+
on_eval_hook = f"on_{loader_name}"
149164

150-
has_loader = is_overridden(loader_name, model)
165+
has_loader = getattr(trainer._data_connector, f"_{stage}_dataloader_source").is_defined()
151166
has_step = is_overridden(step_name, model)
152-
153-
if has_loader and not has_step:
154-
rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")
155-
if has_step and not has_loader:
156-
rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
167+
has_on_eval_dataloader = is_overridden(on_eval_hook, model)
157168

158169
# ----------------------------------------------
159-
# verify model does not have
160-
# - on_val_dataloader
161-
# - on_test_dataloader
170+
# verify model does not have on_eval_dataloader
162171
# ----------------------------------------------
163-
has_on_val_dataloader = is_overridden("on_val_dataloader", model)
164-
if has_on_val_dataloader:
172+
if has_on_eval_dataloader:
165173
rank_zero_deprecation(
166-
"Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
167-
" Please use `val_dataloader()` directly."
174+
f"Method `{on_eval_hook}` is deprecated in v1.5.0 and will"
175+
f" be removed in v1.7.0. Please use `{loader_name}()` directly."
168176
)
169177

170-
has_on_test_dataloader = is_overridden("on_test_dataloader", model)
171-
if has_on_test_dataloader:
172-
rank_zero_deprecation(
173-
"Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
174-
" Please use `test_dataloader()` directly."
175-
)
176-
177-
178-
def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
179-
has_predict_dataloader = trainer._data_connector._predict_dataloader_source.is_defined()
180-
if not has_predict_dataloader:
181-
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
182-
# ----------------------------------------------
183-
# verify model does not have
184-
# - on_predict_dataloader
185-
# ----------------------------------------------
186-
has_on_predict_dataloader = is_overridden("on_predict_dataloader", model)
187-
if has_on_predict_dataloader:
188-
rank_zero_deprecation(
189-
"Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
190-
" Please use `predict_dataloader()` directly."
191-
)
178+
# -----------------------------------
179+
# verify model has an eval_dataloader
180+
# -----------------------------------
181+
if not has_loader:
182+
raise MisconfigurationException(f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`.")
183+
184+
# predict_step is not required to be overridden
185+
if stage == "predict":
186+
if model.predict_step is None:
187+
raise MisconfigurationException("`predict_step` cannot be None to run `Trainer.predict`")
188+
elif not has_step and not is_overridden("forward", model):
189+
raise MisconfigurationException("`Trainer.predict` requires `forward` method to run.")
190+
else:
191+
# -----------------------------------
192+
# verify model has an eval_step
193+
# -----------------------------------
194+
if not has_step:
195+
raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")
192196

193197

194198
def __verify_dp_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:

tests/callbacks/test_pruning.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030

3131

3232
class TestModel(BoringModel):
33-
test_step = None
34-
3533
def __init__(self):
3634
super().__init__()
3735
self.layer = Sequential(

tests/callbacks/test_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ def test_quantization_val_test_predict(tmpdir):
224224
max_epochs=4,
225225
)
226226
trainer.fit(val_test_predict_qmodel, datamodule=dm)
227-
trainer.validate(model=val_test_predict_qmodel, verbose=False)
228-
trainer.test(model=val_test_predict_qmodel, verbose=False)
227+
trainer.validate(model=val_test_predict_qmodel, datamodule=dm, verbose=False)
228+
trainer.test(model=val_test_predict_qmodel, datamodule=dm, verbose=False)
229229
trainer.predict(
230230
model=val_test_predict_qmodel, dataloaders=[torch.utils.data.DataLoader(RandomDataset(num_features, 16))]
231231
)

tests/deprecated_api/test_remove_1-7.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,27 +163,27 @@ def _run(model, task="fit"):
163163
model = CustomBoringModel()
164164

165165
with pytest.deprecated_call(
166-
match="Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
166+
match="Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
167167
):
168168
_run(model, "fit")
169169

170170
with pytest.deprecated_call(
171-
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
171+
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
172172
):
173173
_run(model, "fit")
174174

175175
with pytest.deprecated_call(
176-
match="Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
176+
match="Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
177177
):
178178
_run(model, "validate")
179179

180180
with pytest.deprecated_call(
181-
match="Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
181+
match="Method `on_test_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
182182
):
183183
_run(model, "test")
184184

185185
with pytest.deprecated_call(
186-
match="Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
186+
match="Method `on_predict_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
187187
):
188188
_run(model, "predict")
189189

tests/helpers/test_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def test_models(tmpdir, data_class, model_class):
4040
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
4141

4242
trainer.fit(model, datamodule=dm)
43-
trainer.test(model, datamodule=dm)
43+
44+
if dm is not None:
45+
trainer.test(model, datamodule=dm)
4446

4547
model.to_torchscript()
4648
if data_class:

tests/models/test_restore.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,11 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
331331

332332
# correct result and ok accuracy
333333
assert trainer.state.finished, f"Training failed with {trainer.state}"
334-
pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
334+
pretrained_model = CustomClassificationModelDP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
335335

336336
# run test set
337337
new_trainer = Trainer(**trainer_options)
338-
new_trainer.test(pretrained_model)
338+
new_trainer.test(pretrained_model, datamodule=dm)
339339
pretrained_model.cpu()
340340

341341
dataloaders = dm.test_dataloader()
@@ -383,7 +383,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
383383

384384
# run test set
385385
new_trainer = Trainer(**trainer_options)
386-
new_trainer.test(pretrained_model)
386+
new_trainer.test(pretrained_model, datamodule=dm)
387387
pretrained_model.cpu()
388388

389389
dataloaders = dm.test_dataloader()

tests/plugins/test_deepspeed_plugin.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -915,8 +915,9 @@ def test_dataloader(self):
915915
gpus=1,
916916
fast_dev_run=True,
917917
)
918-
trainer.fit(model, datamodule=TestSetupIsCalledDataModule())
919-
trainer.test(model)
918+
dm = TestSetupIsCalledDataModule()
919+
trainer.fit(model, datamodule=dm)
920+
trainer.test(model, datamodule=dm)
920921

921922

922923
@mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True)

tests/trainer/test_config_validator.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,50 +52,63 @@ def test_fit_val_loop_config(tmpdir):
5252
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
5353

5454
# no val data has val loop
55-
with pytest.warns(UserWarning, match=r"you passed in a val_dataloader but have no validation_step"):
55+
with pytest.warns(UserWarning, match=r"You passed in a `val_dataloader` but have no `validation_step`"):
5656
model = BoringModel()
5757
model.validation_step = None
5858
trainer.fit(model)
5959

6060
# has val loop but no val data
61-
with pytest.warns(UserWarning, match=r"you defined a validation_step but have no val_dataloader"):
61+
with pytest.warns(UserWarning, match=r"You defined a `validation_step` but have no `val_dataloader`"):
6262
model = BoringModel()
6363
model.val_dataloader = None
6464
trainer.fit(model)
6565

6666

67-
def test_test_loop_config(tmpdir):
68-
"""When either test loop or test data are missing."""
67+
def test_eval_loop_config(tmpdir):
68+
"""When either eval step or eval data is missing."""
6969
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
7070

71+
# has val step but no val data
72+
model = BoringModel()
73+
model.val_dataloader = None
74+
with pytest.raises(MisconfigurationException, match=r"No `val_dataloader\(\)` method defined"):
75+
trainer.validate(model)
76+
77+
# has test data but no val step
78+
model = BoringModel()
79+
model.validation_step = None
80+
with pytest.raises(MisconfigurationException, match=r"No `validation_step\(\)` method defined"):
81+
trainer.validate(model)
82+
7183
# has test loop but no test data
72-
with pytest.warns(UserWarning, match=r"you defined a test_step but have no test_dataloader"):
73-
model = BoringModel()
74-
model.test_dataloader = None
84+
model = BoringModel()
85+
model.test_dataloader = None
86+
with pytest.raises(MisconfigurationException, match=r"No `test_dataloader\(\)` method defined"):
7587
trainer.test(model)
7688

77-
# has test data but no test loop
78-
with pytest.warns(UserWarning, match=r"you passed in a test_dataloader but have no test_step"):
79-
model = BoringModel()
80-
model.test_step = None
89+
# has test data but no test step
90+
model = BoringModel()
91+
model.test_step = None
92+
with pytest.raises(MisconfigurationException, match=r"No `test_step\(\)` method defined"):
8193
trainer.test(model)
8294

95+
# has predict step but no predict data
96+
model = BoringModel()
97+
model.predict_dataloader = None
98+
with pytest.raises(MisconfigurationException, match=r"No `predict_dataloader\(\)` method defined"):
99+
trainer.predict(model)
83100

84-
def test_val_loop_config(tmpdir):
85-
"""When either validation loop or validation data are missing."""
86-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
87-
88-
# has val loop but no val data
89-
with pytest.warns(UserWarning, match=r"you defined a validation_step but have no val_dataloader"):
90-
model = BoringModel()
91-
model.val_dataloader = None
92-
trainer.validate(model)
101+
# has predict data but no predict_step
102+
model = BoringModel()
103+
model.predict_step = None
104+
with pytest.raises(MisconfigurationException, match=r"`predict_step` cannot be None."):
105+
trainer.predict(model)
93106

94-
# has val data but no val loop
95-
with pytest.warns(UserWarning, match=r"you passed in a val_dataloader but have no validation_step"):
96-
model = BoringModel()
97-
model.validation_step = None
98-
trainer.validate(model)
107+
# has predict data but no forward
108+
model = BoringModel()
109+
model.forward = None
110+
with pytest.raises(MisconfigurationException, match=r"requires `forward` method to run."):
111+
trainer.predict(model)
99112

100113

101114
@pytest.mark.parametrize("datamodule", [False, True])
@@ -130,11 +143,6 @@ def predict_dataloader(self):
130143
assert len(results) == 2
131144
assert results[0][0].shape == torch.Size([1, 2])
132145

133-
model.predict_dataloader = None
134-
135-
with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
136-
trainer.predict(model)
137-
138146

139147
def test_trainer_manual_optimization_config(tmpdir):
140148
"""Test error message when requesting Trainer features unsupported with manual optimization."""

0 commit comments

Comments
 (0)