diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 5f153034b0610..38a10caa4e957 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1748,6 +1748,155 @@ def _on_exception(self) -> None:
         file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
         self.save_checkpoint(file_path)
 
+    """
+    Data loading methods
+    """
+
+    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
+        etc.).
+
+        Args:
+            model: The ``LightningModule`` if calling this outside of the trainer scope.
+        """
+        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
+
+        if self.overfit_batches > 0:
+            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
+
+        # automatically add samplers
+        self.train_dataloader = apply_to_collection(
+            self.train_dataloader,
+            DataLoader,
+            self._data_connector._prepare_dataloader,
+            shuffle=True,
+            mode=RunningStage.TRAINING,
+        )
+
+        # check the workers recursively
+        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
+
+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
+
+        # add collate_fn to collect metadata for fault tolerant training
+        if _fault_tolerant_training():
+            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
+
+        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
+
+        module = model or self.lightning_module or self.datamodule
+        self.num_training_batches = (
+            len(self.train_dataloader)
+            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
+            else float("inf")
+        )
+
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
+        elif self.num_training_batches != float("inf"):
+            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+        elif self.limit_train_batches != 1.0:
+            raise MisconfigurationException(
+                "When using an IterableDataset for `limit_train_batches`,"
+                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
+                " `num_training_batches` to use."
+            )
+
+        # determine when to check validation
+        # if int passed in, val checks that often
+        # otherwise, it checks in [0, 1.0] % range of a training epoch
+        if isinstance(self.val_check_interval, int):
+            self.val_check_batch = self.val_check_interval
+            if self.val_check_batch > self.num_training_batches:
+                raise ValueError(
+                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
+                    f"to the number of the training batches ({self.num_training_batches}). "
+                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                )
+        else:
+            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float("inf")
+                else:
+                    raise MisconfigurationException(
+                        "When using an IterableDataset for `train_dataloader`,"
+                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                        " checking validation every k training batches."
+                    )
+            else:
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)
+
+        if self.logger and self.num_training_batches < self.log_every_n_steps:
+            rank_zero_warn(
+                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
+                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
+                " you want to see logs for the training epoch.",
+                category=PossibleUserWarning,
+            )
+
+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_train_dl_reload_epoch = self.current_epoch
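For reference, the `limit_train_batches` / `val_check_interval` resolution above reduces to a few lines of arithmetic. Below is a minimal, self-contained sketch of that logic; `num_batches`, `limit`, and `interval` are illustrative stand-ins for the trainer attributes, not Lightning API, and the infinite-length `IterableDataset` path is left out since the method handles it with dedicated errors:

```python
def resolve_limits(num_batches: float, limit, interval) -> tuple:
    """Sketch of the arithmetic in reset_train_dataloader (illustrative only)."""
    # limit_train_batches: an int (or 0.0) caps the batch count directly,
    # while a float in (0, 1] scales the epoch length
    if isinstance(limit, int) or limit == 0.0:
        num_batches = min(num_batches, int(limit))
    elif num_batches != float("inf"):
        num_batches = int(num_batches * limit)

    # val_check_interval: an int means "validate every k batches",
    # a float means "validate that fraction of the way through an epoch"
    if isinstance(interval, int):
        val_check_batch = interval
    else:
        val_check_batch = max(1, int(num_batches * interval))
    return num_batches, val_check_batch


assert resolve_limits(100, 1.0, 0.25) == (100, 25)  # validate 4 times per epoch
assert resolve_limits(100, 0.1, 1.0) == (10, 10)    # train on 10% of the data
assert resolve_limits(100, 5, 2) == (5, 2)          # hard caps on both counts
```

Note that the float path scales the already-limited count, which is why `val_check_interval` is applied after `limit_train_batches` in the method above.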
+
+    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the validation dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._val_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("validation_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.VALIDATING, model=pl_module
+            )
+
+            # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+            self._last_val_dl_reload_epoch = self.current_epoch
+
+    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the test dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._test_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("test_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.TESTING, model=pl_module
+            )
+
+    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the predict dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._predict_dataloader_source
+        pl_module = self.lightning_module or model
+        if source.is_defined():
+            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.PREDICTING, model=pl_module
+            )
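The three eval-stage resets above share one guard: the dataloaders are only rebuilt when a dataloader source is defined and, for val and test, when the corresponding `*_step` hook is overridden (predict skips the hook check). A self-contained sketch of just that gating; `Source`, `Module`, and `reset_val` are hypothetical stand-ins, not Lightning classes:

```python
class Source:
    """Stand-in for the data connector's dataloader source."""

    def __init__(self, fn=None):
        self.fn = fn

    def is_defined(self) -> bool:
        return self.fn is not None


class Module:
    """Stand-in base class; subclasses override validation_step to opt in."""

    def validation_step(self, batch):
        pass


def is_overridden(name: str, obj: Module) -> bool:
    # overridden iff the subclass supplies its own function object
    return getattr(type(obj), name) is not getattr(Module, name)


def reset_val(source: Source, module: Module):
    if source.is_defined() and is_overridden("validation_step", module):
        return source.fn()  # rebuild the dataloaders
    return None  # no-op: nothing to validate with


class NoVal(Module):
    pass


class WithVal(Module):
    def validation_step(self, batch):
        return batch


loader = Source(lambda: ["val_dl"])
assert reset_val(loader, NoVal()) is None        # hook not overridden
assert reset_val(Source(), WithVal()) is None    # no dataloader source defined
assert reset_val(loader, WithVal()) == ["val_dl"]
```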
+ """ + if self.train_dataloader is None: + self.reset_train_dataloader(model=model) + if self.val_dataloaders is None: + self.reset_val_dataloader(model=model) + """ Accelerator properties """ @@ -2377,151 +2526,6 @@ def terminate_on_nan(self, val: bool) -> None: ) self._terminate_on_nan = val # : 212 - def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the train dataloader and initialises required variables (number of batches, when to validate, - etc.). - - Args: - model: The ``LightningModule`` if calling this outside of the trainer scope. - """ - self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model) - - if self.overfit_batches > 0: - self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader) - - # automatically add samplers - self.train_dataloader = apply_to_collection( - self.train_dataloader, - DataLoader, - self._data_connector._prepare_dataloader, - shuffle=True, - mode=RunningStage.TRAINING, - ) - - # check the workers recursively - apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader") - - # add worker_init_fn for correct seeding in worker processes - apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank) - - # add collate_fn to collect metadata for fault tolerant training - if _fault_tolerant_training(): - apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate) - - # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches - self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode) - - module = model or self.lightning_module or self.datamodule - self.num_training_batches = ( - len(self.train_dataloader) - if has_len_all_ranks(self.train_dataloader, self.strategy, module) - else float("inf") - ) - - if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: - self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) - elif self.num_training_batches != float("inf"): - self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) - elif self.limit_train_batches != 1.0: - raise MisconfigurationException( - "When using an IterableDataset for `limit_train_batches`," - " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies" - " `num_training_batches` to use." - ) - - # determine when to check validation - # if int passed in, val checks that often - # otherwise, it checks in [0, 1.0] % range of a training epoch - if isinstance(self.val_check_interval, int): - self.val_check_batch = self.val_check_interval - if self.val_check_batch > self.num_training_batches: - raise ValueError( - f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " - f"to the number of the training batches ({self.num_training_batches}). " - "If you want to disable validation set `limit_val_batches` to 0.0 instead." - ) - else: - if not has_len_all_ranks(self.train_dataloader, self.strategy, module): - if self.val_check_interval == 1.0: - self.val_check_batch = float("inf") - else: - raise MisconfigurationException( - "When using an IterableDataset for `train_dataloader`," - " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies" - " checking validation every k training batches." 
@@ -2377,151 +2526,6 @@ def terminate_on_nan(self, val: bool) -> None:
         )
         self._terminate_on_nan = val  # : 212
 
-    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
-        etc.).
-
-        Args:
-            model: The ``LightningModule`` if calling this outside of the trainer scope.
-        """
-        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
-
-        if self.overfit_batches > 0:
-            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
-
-        # automatically add samplers
-        self.train_dataloader = apply_to_collection(
-            self.train_dataloader,
-            DataLoader,
-            self._data_connector._prepare_dataloader,
-            shuffle=True,
-            mode=RunningStage.TRAINING,
-        )
-
-        # check the workers recursively
-        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
-
-        # add worker_init_fn for correct seeding in worker processes
-        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
-
-        # add collate_fn to collect metadata for fault tolerant training
-        if _fault_tolerant_training():
-            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
-
-        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
-        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
-
-        module = model or self.lightning_module or self.datamodule
-        self.num_training_batches = (
-            len(self.train_dataloader)
-            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
-            else float("inf")
-        )
-
-        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
-            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
-        elif self.num_training_batches != float("inf"):
-            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-        elif self.limit_train_batches != 1.0:
-            raise MisconfigurationException(
-                "When using an IterableDataset for `limit_train_batches`,"
-                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
-                " `num_training_batches` to use."
-            )
-
-        # determine when to check validation
-        # if int passed in, val checks that often
-        # otherwise, it checks in [0, 1.0] % range of a training epoch
-        if isinstance(self.val_check_interval, int):
-            self.val_check_batch = self.val_check_interval
-            if self.val_check_batch > self.num_training_batches:
-                raise ValueError(
-                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
-                    f"to the number of the training batches ({self.num_training_batches}). "
-                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
-                )
-        else:
-            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
-                if self.val_check_interval == 1.0:
-                    self.val_check_batch = float("inf")
-                else:
-                    raise MisconfigurationException(
-                        "When using an IterableDataset for `train_dataloader`,"
-                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
-                        " checking validation every k training batches."
-                    )
-            else:
-                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-                self.val_check_batch = max(1, self.val_check_batch)
-
-        if self.logger and self.num_training_batches < self.log_every_n_steps:
-            rank_zero_warn(
-                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
-                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
-                " you want to see logs for the training epoch.",
-                category=PossibleUserWarning,
-            )
-
-        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-        self._last_train_dl_reload_epoch = self.current_epoch
-
-    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the validation dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._val_dataloader_source
-        pl_module = self.lightning_module or model
-        has_step = is_overridden("validation_step", pl_module)
-        if source.is_defined() and has_step:
-            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.VALIDATING, model=pl_module
-            )
-
-            # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-            self._last_val_dl_reload_epoch = self.current_epoch
-
-    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the test dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._test_dataloader_source
-        pl_module = self.lightning_module or model
-        has_step = is_overridden("test_step", pl_module)
-        if source.is_defined() and has_step:
-            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.TESTING, model=pl_module
-            )
-
-    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the predict dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._predict_dataloader_source
-        pl_module = self.lightning_module or model
-        if source.is_defined():
-            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.PREDICTING, model=pl_module
-            )
-
-    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets train and val dataloaders if none are attached to the trainer.
-
-        The val dataloader must be initialized before training loop starts, as the training loop
-        inspects the val dataloader to determine whether to run the evaluation loop.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        if self.train_dataloader is None:
-            self.reset_train_dataloader(model=model)
-        if self.val_dataloaders is None:
-            self.reset_val_dataloader(model=model)
-
 
 def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1: