diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b417d40484028..0376a0f745f6f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1304,10 +1304,12 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         return self.predict_loop.run()
 
     def _run_sanity_check(self, ref_model):
-        using_val_step = self._data_connector._val_dataloader_source.is_defined() and is_overridden(
-            "validation_step", ref_model
+        should_sanity_check = (
+            self.enable_validation
+            and self.num_sanity_val_steps > 0
+            # do not sanity check if restarting because it would mess up the loaded state
+            and not self._evaluation_loop.restarting
         )
-        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
@@ -1780,9 +1782,11 @@ def _should_reload_dl_epoch(self) -> bool:
     @property
     def enable_validation(self) -> bool:
         """Check if we should run validation during training."""
-        model_ref = self.lightning_module
-        val_loop_enabled = is_overridden("validation_step", model_ref) and self.limit_val_batches > 0
-        return val_loop_enabled
+        return (
+            self._data_connector._val_dataloader_source.is_defined()
+            and is_overridden("validation_step", self.lightning_module)
+            and self.limit_val_batches > 0
+        )
 
     @property
     def default_root_dir(self) -> str:
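
For reference, here is a minimal standalone sketch of the combined predicate after this refactor. The free functions and boolean parameters below are illustrative only, not Trainer API; they mirror the attributes the diff reads (`enable_validation`, `num_sanity_val_steps`, `_evaluation_loop.restarting`):

```python
def enable_validation(
    val_dataloader_defined: bool,
    validation_step_overridden: bool,
    limit_val_batches: float,
) -> bool:
    # Validation runs only if a val dataloader is available, the
    # LightningModule overrides `validation_step`, and `limit_val_batches`
    # is non-zero -- the three conditions now folded into the property.
    return val_dataloader_defined and validation_step_overridden and limit_val_batches > 0


def should_sanity_check(
    val_dataloader_defined: bool,
    validation_step_overridden: bool,
    limit_val_batches: float,
    num_sanity_val_steps: int,
    restarting: bool,
) -> bool:
    # The sanity check additionally requires a positive number of sanity
    # steps, and is skipped when restarting so the evaluation-loop state
    # restored from the checkpoint is not clobbered.
    return (
        enable_validation(val_dataloader_defined, validation_step_overridden, limit_val_batches)
        and num_sanity_val_steps > 0
        and not restarting
    )
```

Net effect: the dataloader-defined check moves out of `_run_sanity_check` into the `enable_validation` property, so both call sites agree on when validation is enabled, and the sanity check gains a guard that skips it when resuming from a checkpoint, since running it at that point would overwrite the restored loop state.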