@@ -29,16 +29,16 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
29
29
30
30
"""
31
31
if trainer .state .fn in (TrainerFn .FITTING , TrainerFn .TUNING ):
32
- __verify_train_loop_configuration (trainer , model )
33
- __verify_eval_loop_configuration (model , "val" )
32
+ __verify_train_val_loop_configuration (trainer , model )
34
33
__verify_manual_optimization_support (trainer , model )
35
34
__check_training_step_requires_dataloader_iter (model )
36
35
elif trainer .state .fn == TrainerFn .VALIDATING :
37
- __verify_eval_loop_configuration (model , "val" )
36
+ __verify_eval_loop_configuration (trainer , model , "val" )
38
37
elif trainer .state .fn == TrainerFn .TESTING :
39
- __verify_eval_loop_configuration (model , "test" )
38
+ __verify_eval_loop_configuration (trainer , model , "test" )
40
39
elif trainer .state .fn == TrainerFn .PREDICTING :
41
- __verify_predict_loop_configuration (trainer , model )
40
+ __verify_eval_loop_configuration (trainer , model , "predict" )
41
+
42
42
__verify_dp_batch_transfer_support (trainer , model )
43
43
_check_add_get_queue (model )
44
44
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
@@ -51,7 +51,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
51
51
_check_dl_idx_in_on_train_batch_hooks (trainer , model )
52
52
53
53
54
- def __verify_train_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
54
+ def __verify_train_val_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
55
55
# -----------------------------------
56
56
# verify model has a training step
57
57
# -----------------------------------
@@ -83,24 +83,15 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
83
83
)
84
84
85
85
# ----------------------------------------------
86
- # verify model does not have
87
- # - on_train_dataloader
88
- # - on_val_dataloader
86
+ # verify model does not have on_train_dataloader
89
87
# ----------------------------------------------
90
88
has_on_train_dataloader = is_overridden ("on_train_dataloader" , model )
91
89
if has_on_train_dataloader :
92
90
rank_zero_deprecation (
93
- "Method `on_train_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
91
+ "Method `on_train_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
94
92
" Please use `train_dataloader()` directly."
95
93
)
96
94
97
- has_on_val_dataloader = is_overridden ("on_val_dataloader" , model )
98
- if has_on_val_dataloader :
99
- rank_zero_deprecation (
100
- "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
101
- " Please use `val_dataloader()` directly."
102
- )
103
-
104
95
trainer .overriden_optimizer_step = is_overridden ("optimizer_step" , model )
105
96
trainer .overriden_optimizer_zero_grad = is_overridden ("optimizer_zero_grad" , model )
106
97
automatic_optimization = model .automatic_optimization
@@ -110,8 +101,30 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
110
101
if has_overriden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization :
111
102
rank_zero_warn (
112
103
"When using `Trainer(accumulate_grad_batches != 1)` and overriding"
113
- "`LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
114
- "(rather, they are called on every optimization step)."
104
+ " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
105
+ " (rather, they are called on every optimization step)."
106
+ )
107
+
108
+ # -----------------------------------
109
+ # verify model for val loop
110
+ # -----------------------------------
111
+
112
+ has_val_loader = trainer ._data_connector ._val_dataloader_source .is_defined ()
113
+ has_val_step = is_overridden ("validation_step" , model )
114
+
115
+ if has_val_loader and not has_val_step :
116
+ rank_zero_warn ("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop." )
117
+ if has_val_step and not has_val_loader :
118
+ rank_zero_warn ("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop." )
119
+
120
+ # ----------------------------------------------
121
+ # verify model does not have on_val_dataloader
122
+ # ----------------------------------------------
123
+ has_on_val_dataloader = is_overridden ("on_val_dataloader" , model )
124
+ if has_on_val_dataloader :
125
+ rank_zero_deprecation (
126
+ "Method `on_val_dataloader` is deprecated in v1.5.0 and will be removed in v1.7.0."
127
+ " Please use `val_dataloader()` directly."
115
128
)
116
129
117
130
@@ -143,52 +156,43 @@ def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
143
156
)
144
157
145
158
146
- def __verify_eval_loop_configuration (model : "pl.LightningModule" , stage : str ) -> None :
159
+ def __verify_eval_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" , stage : str ) -> None :
147
160
loader_name = f"{ stage } _dataloader"
148
- step_name = "validation_step" if stage == "val" else "test_step"
161
+ step_name = "validation_step" if stage == "val" else f"{ stage } _step"
162
+ trainer_method = "validate" if stage == "val" else stage
163
+ on_eval_hook = f"on_{ loader_name } "
149
164
150
- has_loader = is_overridden ( loader_name , model )
165
+ has_loader = getattr ( trainer . _data_connector , f"_ { stage } _dataloader_source" ). is_defined ( )
151
166
has_step = is_overridden (step_name , model )
152
-
153
- if has_loader and not has_step :
154
- rank_zero_warn (f"you passed in a { loader_name } but have no { step_name } . Skipping { stage } loop" )
155
- if has_step and not has_loader :
156
- rank_zero_warn (f"you defined a { step_name } but have no { loader_name } . Skipping { stage } loop" )
167
+ has_on_eval_dataloader = is_overridden (on_eval_hook , model )
157
168
158
169
# ----------------------------------------------
159
- # verify model does not have
160
- # - on_val_dataloader
161
- # - on_test_dataloader
170
+ # verify model does not have on_eval_dataloader
162
171
# ----------------------------------------------
163
- has_on_val_dataloader = is_overridden ("on_val_dataloader" , model )
164
- if has_on_val_dataloader :
172
+ if has_on_eval_dataloader :
165
173
rank_zero_deprecation (
166
- "Method `on_val_dataloader` in DataHooks is deprecated and will be removed in v1.7.0. "
167
- " Please use `val_dataloader ()` directly."
174
+ f "Method `{ on_eval_hook } ` is deprecated in v1.5.0 and will "
175
+ f" be removed in v1.7.0. Please use `{ loader_name } ()` directly."
168
176
)
169
177
170
- has_on_test_dataloader = is_overridden ("on_test_dataloader" , model )
171
- if has_on_test_dataloader :
172
- rank_zero_deprecation (
173
- "Method `on_test_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
174
- " Please use `test_dataloader()` directly."
175
- )
176
-
177
-
178
- def __verify_predict_loop_configuration (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
179
- has_predict_dataloader = trainer ._data_connector ._predict_dataloader_source .is_defined ()
180
- if not has_predict_dataloader :
181
- raise MisconfigurationException ("Dataloader not found for `Trainer.predict`" )
182
- # ----------------------------------------------
183
- # verify model does not have
184
- # - on_predict_dataloader
185
- # ----------------------------------------------
186
- has_on_predict_dataloader = is_overridden ("on_predict_dataloader" , model )
187
- if has_on_predict_dataloader :
188
- rank_zero_deprecation (
189
- "Method `on_predict_dataloader` in DataHooks is deprecated and will be removed in v1.7.0."
190
- " Please use `predict_dataloader()` directly."
191
- )
178
+ # -----------------------------------
179
+ # verify model has an eval_dataloader
180
+ # -----------------------------------
181
+ if not has_loader :
182
+ raise MisconfigurationException (f"No `{ loader_name } ()` method defined to run `Trainer.{ trainer_method } `." )
183
+
184
+ # predict_step is not required to be overridden
185
+ if stage == "predict" :
186
+ if model .predict_step is None :
187
+ raise MisconfigurationException ("`predict_step` cannot be None to run `Trainer.predict`" )
188
+ elif not has_step and not is_overridden ("forward" , model ):
189
+ raise MisconfigurationException ("`Trainer.predict` requires `forward` method to run." )
190
+ else :
191
+ # -----------------------------------
192
+ # verify model has an eval_step
193
+ # -----------------------------------
194
+ if not has_step :
195
+ raise MisconfigurationException (f"No `{ step_name } ()` method defined to run `Trainer.{ trainer_method } `." )
192
196
193
197
194
198
def __verify_dp_batch_transfer_support (trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> None :
0 commit comments