diff --git a/CHANGELOG.md b/CHANGELOG.md index 486ef208c591a..1256865dfa669 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148)) +- Deprecated `TrainerDataLoadingMixin` and moved functionality to `Trainer` and `DataConnector` ([#11282](https://github.com/PyTorchLightning/pytorch-lightning/pull/11282)) + + - Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254)) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index fa3cc1233b2a1..761852d80d1cb 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -188,7 +188,7 @@ def _reload_evaluation_dataloaders(self) -> None: """Reloads dataloaders if necessary.""" if self.trainer.testing: self.trainer.reset_test_dataloader() - elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl: + elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl: self.trainer.reset_val_dataloader() def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 09493f57a59e7..b1e675d52ae5b 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -203,7 +203,7 @@ def on_advance_start(self) -> None: # type: ignore[override] model = self.trainer.lightning_module # reset train dataloader - if not self._is_fresh_start_epoch and self.trainer._should_reload_train_dl: + if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl: self.trainer.reset_train_dataloader(model) 
self._is_fresh_start_epoch = False @@ -223,6 +223,7 @@ def on_advance_start(self) -> None: # type: ignore[override] def advance(self) -> None: # type: ignore[override] """Runs one whole epoch.""" + assert self.trainer.train_dataloader is not None dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader) data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader) diff --git a/pytorch_lightning/strategies/ipu.py b/pytorch_lightning/strategies/ipu.py index bb36a928113a3..1791d58e490ca 100644 --- a/pytorch_lightning/strategies/ipu.py +++ b/pytorch_lightning/strategies/ipu.py @@ -121,8 +121,8 @@ def setup(self, trainer: "pl.Trainer") -> None: # patch the dataloader creation function with the custom `poptorch.DataLoader`. # this violates the intended control flow for the plugins, but since this is experimental, we have chosen # to use the simpler solution before adding abstractions to override the `DataLoader` class - self._update_dataloader_original = pl.trainer.data_loading._update_dataloader - pl.trainer.data_loading._update_dataloader = self._convert_to_poptorch_loader + self._update_dataloader_original = pl.trainer.connectors.data_connector._update_dataloader + pl.trainer.connectors.data_connector._update_dataloader = self._convert_to_poptorch_loader super().setup(trainer) @@ -278,7 +278,7 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT: def teardown(self) -> None: super().teardown() # undo dataloader patching - pl.trainer.data_loading._update_dataloader = self._update_dataloader_original + pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original for model in self.poptorch_models.values(): model.destroy() diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index eae699081f141..2b17271c8c918 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ 
b/pytorch_lightning/trainer/connectors/data_connector.py @@ -11,15 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing import os from dataclasses import dataclass from functools import partial -from typing import Iterable, Optional, Union +from typing import Any, Collection, Iterable, List, Optional, Tuple, Union from weakref import proxy +from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler +from torch.utils.data.dataset import IterableDataset +from torch.utils.data.distributed import DistributedSampler + import pytorch_lightning as pl +from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.utilities.auto_restart import _teardown_dataloader_get_iterators +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.auto_restart import ( + _teardown_dataloader_get_iterators, + _validate_fault_tolerant_automatic, +) +from pytorch_lightning.utilities.data import ( + _auto_add_worker_init_fn, + _replace_dataloader_init_method, + _update_dataloader, + has_iterable_dataset, + has_len_all_ranks, +) +from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import ( AbstractDataFetcher, @@ -27,10 +47,11 @@ DataLoaderIterDataFetcher, InterBatchParallelDataFetcher, ) +from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import 
is_param_in_hook_signature from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -from pytorch_lightning.utilities.warnings import rank_zero_warn +from pytorch_lightning.utilities.warnings import PossibleUserWarning, rank_zero_warn class DataConnector: @@ -61,6 +82,18 @@ def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: return self.sanity_check_data_fetcher return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher + @property + def _should_reload_train_dl(self) -> bool: + """Check if train dataloader should be reloaded.""" + n_epochs = self.trainer.reload_dataloaders_every_n_epochs + return n_epochs and (self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs) + + @property + def _should_reload_val_dl(self) -> bool: + """Check if validation dataloader should be reloaded.""" + n_epochs = self.trainer.reload_dataloaders_every_n_epochs + return n_epochs and (self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs) + def on_trainer_init( self, check_val_every_n_epoch: int, @@ -242,6 +275,277 @@ def attach_datamodule( if hasattr(datamodule, "data_pipeline"): model.data_pipeline = datamodule.data_pipeline + def _worker_check(self, dataloader: DataLoader, name: str) -> None: + if not isinstance(dataloader, DataLoader): + return + + using_spawn = self.trainer._accelerator_connector._distrib_type == _StrategyType.DDP_SPAWN + num_cpus = multiprocessing.cpu_count() + + # ddp_spawn + num_workers > 0 don't mix! tell the user + if dataloader.num_workers > 0 and using_spawn: + # checks for the attr persistent_workers available in pytorch >= 1.7 + if hasattr(dataloader, "persistent_workers"): + if not dataloader.persistent_workers: + rank_zero_warn( + "num_workers>0, persistent_workers=False, and strategy=ddp_spawn" + " may result in data loading bottlenecks." 
+ " Consider setting persistent_workers=True" + " (this is a limitation of Python .spawn() and PyTorch)" + ) + else: + rank_zero_warn( + "num_workers>0 and strategy=ddp_spawn do not mix well" + " and may result in data loading bottlenecks." + " Consider setting strategy=ddp to use num_workers>0" + " (this is a limitation of Python .spawn() and PyTorch)" + ) + + elif dataloader.num_workers == 0 and using_spawn: + # checks for the attr persistent_workers available in pytorch >= 1.7 + if hasattr(dataloader, "persistent_workers"): + if not dataloader.persistent_workers: + rank_zero_warn( + "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." + " Consider setting num_workers>0 and persistent_workers=True" + ) + else: + rank_zero_warn( + "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." + " Consider setting strategy=ddp and set num_workers>0" + ) + + elif dataloader.num_workers <= 2 < num_cpus and not using_spawn: + # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' + rank_zero_warn( + f"The dataloader, {name}, does not have many workers which may be a bottleneck." 
+ " Consider increasing the value of the `num_workers` argument`" + f" (try {num_cpus} which is the number of cpus on this machine)" + " in the `DataLoader` init to improve performance.", + category=PossibleUserWarning, + ) + + def _requires_distributed_sampler(self, dataloader) -> bool: + return ( + self.trainer._accelerator_connector.replace_sampler_ddp + and self.trainer._accelerator_connector.is_distributed + and not isinstance(dataloader.sampler, DistributedSampler) + and not has_iterable_dataset(dataloader) + ) + + def _prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any: + """This function handles the following functionalities: + + - Injecting a `DistributedSampler` into the `DataLoader` if on a distributed environment + - Wrapping the datasets and samplers into fault-tolerant components + """ + if isinstance(dataloader, CombinedLoader): + # apply `_prepare_dataloader` on all the collection of loaders + dataloader.loaders = apply_to_collection( + dataloader.loaders, (DataLoader, CycleIterator), self._prepare_dataloader, shuffle, mode=mode + ) + # the length needs to be recomputed across all dataloaders in case of special behavior. 
+ dataloader._apply_cycle_iterator_length() + return dataloader + + # don't do anything if it's not a dataloader + if not isinstance(dataloader, (DataLoader, CycleIterator)): + return dataloader + + cycle_iterator: Optional[CycleIterator] = None + + if isinstance(dataloader, CycleIterator): + cycle_iterator = dataloader + dataloader = dataloader.loader + + if ( + _fault_tolerant_training() # injects components to track the state + or self._requires_distributed_sampler(dataloader) # sets the distributed sampler + or mode == RunningStage.PREDICTING # to track indices for the predictions + or self.trainer._accelerator_connector.use_ipu # IPUs use a custom `DataLoader` + ): + sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode) + dataloader = _update_dataloader(dataloader, sampler, mode=mode) + + if cycle_iterator is not None: + cycle_iterator.loader = dataloader + return cycle_iterator + + return dataloader + + def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler: + if self._requires_distributed_sampler(dataloader): + if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): + raise MisconfigurationException( + "You seem to have configured a sampler in your DataLoader. This will be replaced" + " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using" + " distributed training. Either remove the sampler from your DataLoader or set" + " `replace_sampler_ddp=False` if you want to use your custom sampler." 
+ ) + return self._get_distributed_sampler( + dataloader, + shuffle, + mode=mode, + overfit_batches=self.trainer.overfit_batches, + **self.trainer.distributed_sampler_kwargs, + ) + + return dataloader.sampler + + @staticmethod + def _get_distributed_sampler( + dataloader: DataLoader, + shuffle: bool, + overfit_batches: Union[int, float], + mode: Optional[RunningStage] = None, + **kwargs: Any, + ) -> DistributedSampler: + """This function is used to create the distributed sampler injected within the user DataLoader.""" + kwargs["shuffle"] = shuffle and not overfit_batches + kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) + cls = UnrepeatedDistributedSampler if mode == RunningStage.PREDICTING else DistributedSampler + sampler = cls(dataloader.dataset, **kwargs) + return sampler + + def _reset_eval_dataloader( + self, mode: RunningStage, model: Optional["pl.LightningModule"] = None + ) -> Tuple[List[Union[int, float]], List[DataLoader]]: + """Generic method to reset a dataloader for evaluation. + + Args: + mode: The running stage of the ``Trainer`` + model: The ``LightningModule`` if calling this outside of the trainer scope. 
+ + Returns: + Tuple (num_batches, dataloaders) + """ + assert mode.evaluating or mode == RunningStage.PREDICTING + + # always get the loaders first so we can count how many there are + dataloaders = self._request_dataloader(mode, model=model) + + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + if any(dl is None for dl in dataloaders): + rank_zero_warn("One of given dataloaders is None and it will be skipped.") + + for loader in dataloaders: + apply_to_collection( + loader.loaders if isinstance(loader, CombinedLoader) else loader, + DataLoader, + self._check_eval_shuffling, + mode=mode, + ) + + # add samplers + dataloaders = [self._prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None] + + # add worker_init_fn for correct seeding in worker processes + apply_to_collection( + dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.trainer.global_rank + ) + + loader_num_batches = [] + + # determine number of batches + # datasets could be none, 1 or 2+ + module = model or self.trainer.lightning_module or self.datamodule + if len(dataloaders) != 0: + for i, dataloader in enumerate(dataloaders): + orig_num_batches = num_batches = ( + len(dataloader) if has_len_all_ranks(dataloader, self.trainer.strategy, module) else float("inf") + ) + self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}") + + # percent or num_steps + limit_eval_batches = getattr(self.trainer, f"limit_{mode.dataloader_prefix}_batches") + + # limit num batches either as a percent or num steps + if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0: + num_batches = min(num_batches, int(limit_eval_batches)) + elif num_batches != float("inf"): + num_batches = int(num_batches * limit_eval_batches) + elif limit_eval_batches != 1.0: + raise MisconfigurationException( + f"When using an IterableDataset for `limit_{mode}_batches`," + f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or 
an int. An int k" + f" specifies `num_{mode.dataloader_prefix}_batches` to use." + ) + + if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float): + min_pct = 1.0 / len(dataloader) + raise MisconfigurationException( + f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but" + f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the" + f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least" + f" `limit_{mode.dataloader_prefix}_batches={min_pct}`" + ) + + loader_num_batches.append(num_batches) + + return loader_num_batches, dataloaders + + def _request_dataloader( + self, stage: RunningStage, model: Optional["pl.LightningModule"] = None + ) -> Union[DataLoader, List[DataLoader]]: + """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage. + + Returns: + The requested dataloader + """ + source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") + + hook = f"{stage.dataloader_prefix}_dataloader" + self.trainer._call_lightning_module_hook("on_" + hook, pl_module=model) + with _replace_dataloader_init_method(): + # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as + # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning + dataloader = source.dataloader() + if isinstance(dataloader, tuple): + dataloader = list(dataloader) + self.trainer.strategy.barrier("get_dataloaders") + _validate_fault_tolerant_automatic(dataloader, stage) + return dataloader + + @staticmethod + def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]: + all_have_sequential_sampler = True + + def resolve_has_no_sequential_sampler(dataloader: DataLoader): + nonlocal all_have_sequential_sampler + all_have_sequential_sampler = all_have_sequential_sampler & isinstance( + dataloader.sampler, SequentialSampler + ) + + 
apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler) + + if not all_have_sequential_sampler: + rank_zero_warn( + "You requested to overfit but enabled training dataloader shuffling." + " We are turning off the training dataloader shuffling for you." + ) + + def replace_sampler(dataloader: DataLoader) -> DataLoader: + return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING) + + dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler) + + return dataloader + + @staticmethod + def _check_eval_shuffling(dataloader, mode): + if ( + hasattr(dataloader, "sampler") + and not isinstance(dataloader.sampler, SequentialSampler) + and not isinstance(dataloader.dataset, IterableDataset) + ): + rank_zero_warn( + f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`," + " it is strongly recommended that you turn this off for val/test/predict dataloaders.", + category=PossibleUserWarning, + ) + def teardown(self) -> None: if self.train_data_fetcher: self.train_data_fetcher.teardown() diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 76dc4ce989c66..99a92d5af9dd1 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -11,483 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import multiprocessing -import os from abc import ABC -from typing import Any, Callable, Collection, List, Optional, Tuple, Union +from typing import Any, List, Optional, Union -from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler -from torch.utils.data.dataset import IterableDataset -from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler -from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic -from pytorch_lightning.utilities.data import ( - _auto_add_worker_init_fn, - _replace_dataloader_init_method, - _update_dataloader, - has_iterable_dataset, - has_len_all_ranks, -) -from pytorch_lightning.utilities.enums import _StrategyType -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.warnings import PossibleUserWarning +from pytorch_lightning.utilities import rank_zero_deprecation class TrainerDataLoadingMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - val_check_interval: float - reload_dataloaders_every_n_epochs: int - tpu_local_core_rank: int - train_dataloader: DataLoader - limit_train_batches: Union[int, 
float] - num_training_batches: int - val_check_batch: float - val_dataloaders: List[DataLoader] - limit_val_batches: Union[int, float] - num_val_batches: List[int] - test_dataloaders: List[DataLoader] - limit_test_batches: Union[int, float] - num_test_batches: List[int] - predict_dataloaders: List[DataLoader] - limit_predict_batches: Union[int, float] - num_predict_batches: List[int] - log_every_n_steps: int - overfit_batches: Union[int, float] - distributed_sampler_kwargs: dict - accelerator: Accelerator - call_hook: Callable - current_epoch: int - _accelerator_connector: AcceleratorConnector - _last_train_dl_reload_epoch: int - _last_val_dl_reload_epoch: int - - @property - def _should_reload_train_dl(self) -> bool: - """Check if train dataloader should be reloaded.""" - n_epochs = self.reload_dataloaders_every_n_epochs - return n_epochs and (self.current_epoch - self._last_train_dl_reload_epoch >= n_epochs) - - @property - def _should_reload_val_dl(self) -> bool: - """Check if validation dataloader should be reloaded.""" - n_epochs = self.reload_dataloaders_every_n_epochs - return n_epochs and (self.current_epoch - self._last_val_dl_reload_epoch >= n_epochs) - - def _worker_check(self, dataloader: DataLoader, name: str) -> None: - if not isinstance(dataloader, DataLoader): - return - - using_spawn = self._accelerator_connector._distrib_type == _StrategyType.DDP_SPAWN - num_cpus = multiprocessing.cpu_count() - - # ddp_spawn + num_workers > 0 don't mix! tell the user - if dataloader.num_workers > 0 and using_spawn: - # checks for the attr persistent_workers available in pytorch >= 1.7 - if hasattr(dataloader, "persistent_workers"): - if not dataloader.persistent_workers: - rank_zero_warn( - "num_workers>0, persistent_workers=False, and strategy=ddp_spawn" - " may result in data loading bottlenecks." 
- " Consider setting persistent_workers=True" - " (this is a limitation of Python .spawn() and PyTorch)" - ) - else: - rank_zero_warn( - "num_workers>0 and strategy=ddp_spawn do not mix well" - " and may result in data loading bottlenecks." - " Consider setting strategy=ddp to use num_workers>0" - " (this is a limitation of Python .spawn() and PyTorch)" - ) - - elif dataloader.num_workers == 0 and using_spawn: - # checks for the attr persistent_workers available in pytorch >= 1.7 - if hasattr(dataloader, "persistent_workers"): - if not dataloader.persistent_workers: - rank_zero_warn( - "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." - " Consider setting num_workers>0 and persistent_workers=True" - ) - else: - rank_zero_warn( - "strategy=ddp_spawn and num_workers=0 may result in data loading bottlenecks." - " Consider setting strategy=ddp and set num_workers>0" - ) - - elif dataloader.num_workers <= 2 < num_cpus and not using_spawn: - # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' - rank_zero_warn( - f"The dataloader, {name}, does not have many workers which may be a bottleneck." - " Consider increasing the value of the `num_workers` argument`" - f" (try {num_cpus} which is the number of cpus on this machine)" - " in the `DataLoader` init to improve performance.", - category=PossibleUserWarning, - ) - - def _requires_distributed_sampler(self, dataloader) -> bool: - return ( - self._accelerator_connector.replace_sampler_ddp - and self._accelerator_connector.is_distributed - and not isinstance(dataloader.sampler, DistributedSampler) - and not has_iterable_dataset(dataloader) - ) + r""" + .. deprecated:: v1.6 + The `TrainerDataLoadingMixin` class was deprecated in v1.6 and will be removed in v1.8. + """ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any: - """This function handles to following functionalities: + r""" + .. 
deprecated:: v1.6 + `TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 + and will be removed in v1.8. + + This function handles the following functionalities: - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment - Wrapping the datasets and samplers into fault-tolerant components """ - if isinstance(dataloader, CombinedLoader): - # apply `prepare_dataloader` on all the collection of loaders - dataloader.loaders = apply_to_collection( - dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode - ) - # the length need to recomputed across all dataloaders in case of special behavior. - dataloader._apply_cycle_iterator_length() - return dataloader - - # don't do anything if it's not a dataloader - if not isinstance(dataloader, (DataLoader, CycleIterator)): - return dataloader - - cycle_iterator: Optional[CycleIterator] = None - - if isinstance(dataloader, CycleIterator): - cycle_iterator = dataloader - dataloader = dataloader.loader - - if ( - _fault_tolerant_training() # injects components to track the state - or self._requires_distributed_sampler(dataloader) # sets the distributed sampler - or mode == RunningStage.PREDICTING # to track indices for the predictions - or self._accelerator_connector.use_ipu # IPUs use a custom `DataLoader` - ): - sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode) - dataloader = _update_dataloader(dataloader, sampler, mode=mode) - - if cycle_iterator is not None: - cycle_iterator.loader = dataloader - return cycle_iterator - - return dataloader - - def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler: - if self._requires_distributed_sampler(dataloader): - if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): - raise MisconfigurationException( - "You seem to have configured a sampler in your DataLoader. 
This will be replaced" - " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using" - " distributed training. Either remove the sampler from your DataLoader or set" - " `replace_sampler_ddp=False` if you want to use your custom sampler." - ) - return self._get_distributed_sampler( - dataloader, shuffle, mode=mode, overfit_batches=self.overfit_batches, **self.distributed_sampler_kwargs - ) - - return dataloader.sampler - - @staticmethod - def _get_distributed_sampler( - dataloader: DataLoader, - shuffle: bool, - overfit_batches: Union[int, float], - mode: Optional[RunningStage] = None, - **kwargs: Any, - ) -> DistributedSampler: - """This function is used to created the distributed sampler injected within the user DataLoader.""" - kwargs["shuffle"] = shuffle and not overfit_batches - kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) - cls = UnrepeatedDistributedSampler if mode == RunningStage.PREDICTING else DistributedSampler - sampler = cls(dataloader.dataset, **kwargs) - return sampler - - def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the train dataloader and initialises required variables (number of batches, when to validate, - etc.). - - Args: - model: The `LightningModule` if calling this outside of the trainer scope. - """ - self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model) - - if self.overfit_batches > 0: - self.train_dataloader = self._resolve_overfit_batches(self.train_dataloader) - - # automatically add samplers - self.train_dataloader = apply_to_collection( - self.train_dataloader, DataLoader, self.prepare_dataloader, shuffle=True, mode=RunningStage.TRAINING + rank_zero_deprecation( + "`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 and will be removed in v1.8." 
) - - # check the workers recursively - apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader") - - # add worker_init_fn for correct seeding in worker processes - apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank) - - # add collate_fn to collect metadata for fault tolerant training - if _fault_tolerant_training(): - apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate) - - # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches - self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode) - - module = model or self.lightning_module or self.datamodule - self.num_training_batches = ( - len(self.train_dataloader) - if has_len_all_ranks(self.train_dataloader, self.strategy, module) - else float("inf") - ) - - if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: - self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) - elif self.num_training_batches != float("inf"): - self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) - elif self.limit_train_batches != 1.0: - raise MisconfigurationException( - "When using an IterableDataset for `limit_train_batches`," - " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies" - " `num_training_batches` to use." - ) - - # determine when to check validation - # if int passed in, val checks that often - # otherwise, it checks in [0, 1.0] % range of a training epoch - if isinstance(self.val_check_interval, int): - self.val_check_batch = self.val_check_interval - if self.val_check_batch > self.num_training_batches: - raise ValueError( - f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " - f"to the number of the training batches ({self.num_training_batches}). 
" - "If you want to disable validation set `limit_val_batches` to 0.0 instead." - ) - else: - if not has_len_all_ranks(self.train_dataloader, self.strategy, module): - if self.val_check_interval == 1.0: - self.val_check_batch = float("inf") - else: - raise MisconfigurationException( - "When using an IterableDataset for `train_dataloader`," - " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies" - " checking validation every k training batches." - ) - else: - self.val_check_batch = int(self.num_training_batches * self.val_check_interval) - self.val_check_batch = max(1, self.val_check_batch) - - if self.logger and self.num_training_batches < self.log_every_n_steps: - rank_zero_warn( - f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval" - f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if" - " you want to see logs for the training epoch.", - category=PossibleUserWarning, - ) - - # store epoch of dataloader reset for reload_dataloaders_every_n_epochs - self._last_train_dl_reload_epoch = self.current_epoch - - def _reset_eval_dataloader( - self, mode: RunningStage, model: Optional["pl.LightningModule"] = None - ) -> Tuple[List[Union[int, float]], List[DataLoader]]: - """Generic method to reset a dataloader for evaluation. - - Args: - mode: The running stage of the ``Trainer`` - model: The ``LightningModule`` if calling this outside of the trainer scope. 
- - Returns: - Tuple (num_batches, dataloaders) - """ - assert mode.evaluating or mode == RunningStage.PREDICTING - - # always get the loaders first so we can count how many there are - dataloaders = self.request_dataloader(mode, model=model) - - if not isinstance(dataloaders, list): - dataloaders = [dataloaders] - - if any(dl is None for dl in dataloaders): - rank_zero_warn("One of given dataloaders is None and it will be skipped.") - - for loader in dataloaders: - apply_to_collection( - loader.loaders if isinstance(loader, CombinedLoader) else loader, - DataLoader, - self._check_eval_shuffling, - mode=mode, - ) - - # add samplers - dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None] - - # add worker_init_fn for correct seeding in worker processes - apply_to_collection(dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.global_rank) - - loader_num_batches = [] - - # determine number of batches - # datasets could be none, 1 or 2+ - module = model or self.lightning_module or self.datamodule - if len(dataloaders) != 0: - for i, dataloader in enumerate(dataloaders): - orig_num_batches = num_batches = ( - len(dataloader) if has_len_all_ranks(dataloader, self.strategy, module) else float("inf") - ) - self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}") - - # percent or num_steps - limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches") - - # limit num batches either as a percent or num steps - if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0: - num_batches = min(num_batches, int(limit_eval_batches)) - elif num_batches != float("inf"): - num_batches = int(num_batches * limit_eval_batches) - elif limit_eval_batches != 1.0: - raise MisconfigurationException( - f"When using an IterableDataset for `limit_{mode}_batches`," - f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. 
An int k" - f" specifies `num_{mode.dataloader_prefix}_batches` to use." - ) - - if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float): - min_pct = 1.0 / len(dataloader) - raise MisconfigurationException( - f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but" - f" {limit_eval_batches} * {orig_num_batches} < 1. Please increase the" - f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least" - f" `limit_{mode.dataloader_prefix}_batches={min_pct}`" - ) - - loader_num_batches.append(num_batches) - - return loader_num_batches, dataloaders - - def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the validation dataloader and determines the number of batches. - - Args: - model: The `LightningModule` if called outside of the trainer scope. - """ - source = self._data_connector._val_dataloader_source - pl_module = self.lightning_module or model - has_step = is_overridden("validation_step", pl_module) - if source.is_defined() and has_step: - self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader( - RunningStage.VALIDATING, model=pl_module - ) - - # store epoch of dataloader reset for reload_dataloaders_every_n_epochs - self._last_val_dl_reload_epoch = self.current_epoch - - def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the test dataloader and determines the number of batches. - - Args: - model: The `LightningModule` if called outside of the trainer scope. 
- """ - source = self._data_connector._test_dataloader_source - pl_module = self.lightning_module or model - has_step = is_overridden("test_step", pl_module) - if source.is_defined() and has_step: - self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader( - RunningStage.TESTING, model=pl_module - ) - - def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the predict dataloader and determines the number of batches. - - Args: - model: The `LightningModule` if called outside of the trainer scope. - """ - source = self._data_connector._predict_dataloader_source - pl_module = self.lightning_module or model - if source.is_defined(): - self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader( - RunningStage.PREDICTING, model=pl_module - ) - - def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets train and val dataloaders if none are attached to the trainer. - - The val dataloader must be initialized before training loop starts, as the training loop - inspects the val dataloader to determine whether to run the evaluation loop. - - Args: - model: The `LightningModule` if called outside of the trainer scope. - """ - if self.train_dataloader is None: - self.reset_train_dataloader(model=model) - if self.val_dataloaders is None: - self.reset_val_dataloader(model=model) + return self._data_connector._prepare_dataloader(dataloader, shuffle, mode) def request_dataloader( self, stage: RunningStage, model: Optional["pl.LightningModule"] = None ) -> Union[DataLoader, List[DataLoader]]: - """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage. + r""" + .. deprecated:: v1.6 + `TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 + and will be removed in v1.8. + + Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage. 
Returns: The requested dataloader """ - source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source") - - hook = f"{stage.dataloader_prefix}_dataloader" - self._call_lightning_module_hook("on_" + hook, pl_module=model) - with _replace_dataloader_init_method(): - # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as - # attributes on the instance in case the dataloader needs to be re-instantiated later by Ligtning - dataloader = source.dataloader() - if isinstance(dataloader, tuple): - dataloader = list(dataloader) - self.strategy.barrier("get_dataloaders") - _validate_fault_tolerant_automatic(dataloader, stage) - return dataloader - - @staticmethod - def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]: - all_have_sequential_sampler = True - - def resolve_has_no_sequential_sampler(dataloader: DataLoader): - nonlocal all_have_sequential_sampler - all_have_sequential_sampler = all_have_sequential_sampler & isinstance( - dataloader.sampler, SequentialSampler - ) - - apply_to_collection(dataloader, DataLoader, resolve_has_no_sequential_sampler) - - if not all_have_sequential_sampler: - rank_zero_warn( - "You requested to overfit but enabled training dataloader shuffling." - " We are turning off the training dataloader shuffling for you." 
- ) - - def replace_sampler(dataloader: DataLoader) -> DataLoader: - return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING) - - dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler) - - return dataloader - - @staticmethod - def _check_eval_shuffling(dataloader, mode): - if ( - hasattr(dataloader, "sampler") - and not isinstance(dataloader.sampler, SequentialSampler) - and not isinstance(dataloader.dataset, IterableDataset) - ): - rank_zero_warn( - f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`," - " it is strongly recommended that you turn this off for val/test/predict dataloaders.", - category=PossibleUserWarning, - ) + rank_zero_deprecation( + "`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8." + ) + return self._data_connector._request_dataloader(stage, model) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5a6f67c9ea0d3..9931cc31eddcf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -27,6 +27,7 @@ import torch from packaging.version import Version from torch.optim import Optimizer +from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator, IPUAccelerator @@ -70,6 +71,7 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import ( @@ -85,6 +87,7 @@ rank_zero_info, rank_zero_warn, ) +from pytorch_lightning.utilities.apply_func import apply_to_collection from 
pytorch_lightning.utilities.argparse import ( _defaults_from_env_vars, add_argparse_args, @@ -92,7 +95,9 @@ parse_argparser, parse_env_variables, ) +from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -119,9 +124,9 @@ class Trainer( - TrainerCallbackHookMixin, - TrainerOptimizersMixin, - TrainerDataLoadingMixin, + TrainerCallbackHookMixin, # TODO: Remove in v1.8 + TrainerOptimizersMixin, # TODO: Remove in v1.8 + TrainerDataLoadingMixin, # TODO: Remove in v1.8 ): # Needed because of LightningOptimizer _lightning_optimizers = None @@ -2364,6 +2369,151 @@ def terminate_on_nan(self, val: bool) -> None: ) self._terminate_on_nan = val # : 212 + def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: + """Resets the train dataloader and initialises required variables (number of batches, when to validate, + etc.). + + Args: + model: The ``LightningModule`` if calling this outside of the trainer scope. 
+ """ + self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model) + + if self.overfit_batches > 0: + self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader) + + # automatically add samplers + self.train_dataloader = apply_to_collection( + self.train_dataloader, + DataLoader, + self._data_connector._prepare_dataloader, + shuffle=True, + mode=RunningStage.TRAINING, + ) + + # check the workers recursively + apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader") + + # add worker_init_fn for correct seeding in worker processes + apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank) + + # add collate_fn to collect metadata for fault tolerant training + if _fault_tolerant_training(): + apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate) + + # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches + self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode) + + module = model or self.lightning_module or self.datamodule + self.num_training_batches = ( + len(self.train_dataloader) + if has_len_all_ranks(self.train_dataloader, self.strategy, module) + else float("inf") + ) + + if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: + self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches)) + elif self.num_training_batches != float("inf"): + self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) + elif self.limit_train_batches != 1.0: + raise MisconfigurationException( + "When using an IterableDataset for `limit_train_batches`," + " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies" + " `num_training_batches` to use." 
+ ) + + # determine when to check validation + # if int passed in, val checks that often + # otherwise, it checks in [0, 1.0] % range of a training epoch + if isinstance(self.val_check_interval, int): + self.val_check_batch = self.val_check_interval + if self.val_check_batch > self.num_training_batches: + raise ValueError( + f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " + f"to the number of the training batches ({self.num_training_batches}). " + "If you want to disable validation set `limit_val_batches` to 0.0 instead." + ) + else: + if not has_len_all_ranks(self.train_dataloader, self.strategy, module): + if self.val_check_interval == 1.0: + self.val_check_batch = float("inf") + else: + raise MisconfigurationException( + "When using an IterableDataset for `train_dataloader`," + " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies" + " checking validation every k training batches." + ) + else: + self.val_check_batch = int(self.num_training_batches * self.val_check_interval) + self.val_check_batch = max(1, self.val_check_batch) + + if self.logger and self.num_training_batches < self.log_every_n_steps: + rank_zero_warn( + f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval" + f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if" + " you want to see logs for the training epoch.", + category=PossibleUserWarning, + ) + + # store epoch of dataloader reset for reload_dataloaders_every_n_epochs + self._last_train_dl_reload_epoch = self.current_epoch + + def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: + """Resets the validation dataloader and determines the number of batches. + + Args: + model: The ``LightningModule`` if called outside of the trainer scope. 
+ """ + source = self._data_connector._val_dataloader_source + pl_module = self.lightning_module or model + has_step = is_overridden("validation_step", pl_module) + if source.is_defined() and has_step: + self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader( + RunningStage.VALIDATING, model=pl_module + ) + + # store epoch of dataloader reset for reload_dataloaders_every_n_epochs + self._last_val_dl_reload_epoch = self.current_epoch + + def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: + """Resets the test dataloader and determines the number of batches. + + Args: + model: The ``LightningModule`` if called outside of the trainer scope. + """ + source = self._data_connector._test_dataloader_source + pl_module = self.lightning_module or model + has_step = is_overridden("test_step", pl_module) + if source.is_defined() and has_step: + self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader( + RunningStage.TESTING, model=pl_module + ) + + def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: + """Resets the predict dataloader and determines the number of batches. + + Args: + model: The ``LightningModule`` if called outside of the trainer scope. + """ + source = self._data_connector._predict_dataloader_source + pl_module = self.lightning_module or model + if source.is_defined(): + self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader( + RunningStage.PREDICTING, model=pl_module + ) + + def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None: + """Resets train and val dataloaders if none are attached to the trainer. + + The val dataloader must be initialized before training loop starts, as the training loop + inspects the val dataloader to determine whether to run the evaluation loop. 
+ Args: + model: The ``LightningModule`` if called outside of the trainer scope. + """ + if self.train_dataloader is None: + self.reset_train_dataloader(model=model) + if self.val_dataloaders is None: + self.reset_val_dataloader(model=model) + def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: if 0 <= batches <= 1: diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 8eaebc0b51684..52c57a1f70247 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -31,11 +31,12 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import DeviceType, DistributedType from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY -from tests.helpers.boring_model import BoringModel +from tests.helpers.boring_model import BoringDataModule, BoringModel from tests.helpers.runif import RunIf from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator @@ -271,6 +272,22 @@ def test_v1_8_0_deprecated_training_type_plugin_property(): trainer.training_type_plugin +def test_v1_8_0_deprecate_trainer_data_loading_mixin(): + trainer = Trainer(max_epochs=1) + model = BoringModel() + dm = BoringDataModule() + trainer.fit(model, datamodule=dm) + + with pytest.deprecated_call( + match=r"`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 and will be removed in v1.8.", + ): + trainer.prepare_dataloader(dataloader=model.train_dataloader, shuffle=False) + with pytest.deprecated_call( + 
match=r"`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8.", + ): + trainer.request_dataloader(stage=RunningStage.TRAINING) + + def test_v_1_8_0_deprecated_device_stats_monitor_prefix_metric_keys(): from pytorch_lightning.callbacks.device_stats_monitor import prefix_metric_keys diff --git a/tests/trainer/flags/test_limit_batches.py b/tests/trainer/flags/test_limit_batches.py index 99c3f6a4daf42..92febaeb7f83b 100644 --- a/tests/trainer/flags/test_limit_batches.py +++ b/tests/trainer/flags/test_limit_batches.py @@ -62,7 +62,7 @@ def test_eval_limit_batches(stage, mode, limit_batches): trainer = Trainer(**{limit_eval_batches: limit_batches}) model.trainer = trainer trainer._data_connector.attach_dataloaders(model) - loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model) + loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model) expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches assert loader_num_batches[0] == expected_batches assert len(dataloaders[0]) == len(eval_loader) diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py index 9e9bbe9cb06e2..7178fbd9c065e 100644 --- a/tests/trainer/flags/test_overfit_batches.py +++ b/tests/trainer/flags/test_overfit_batches.py @@ -80,7 +80,7 @@ def test_overfit_batch_limits_eval(stage, mode, overfit_batches): model.trainer = trainer trainer._data_connector.attach_datamodule(model, datamodule=dm) - loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model) + loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model) if stage == RunningStage.VALIDATING: assert loader_num_batches[0] == 0 else: diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 95b9f9061b366..d8b43b1d4d1c4 100644 --- 
a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -248,17 +248,17 @@ def __init__( class CustomDummyObj: sampler = None - result = trainer.prepare_dataloader(CustomDummyObj(), shuffle=True) + result = trainer._data_connector._prepare_dataloader(CustomDummyObj(), shuffle=True) assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader" dataset = list(range(10)) - result = trainer.prepare_dataloader(CustomDataLoader(dataset), shuffle=True) + result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset), shuffle=True) assert isinstance(result, DataLoader) assert isinstance(result, CustomDataLoader) assert result.dummy_kwarg is None # Shuffled DataLoader should also work - result = trainer.prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True) + result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True) assert isinstance(result, DataLoader) assert isinstance(result, CustomDataLoader) assert result.dummy_kwarg is None @@ -269,7 +269,7 @@ class CustomSampler(Sampler): # Should raise an error if existing sampler is being replaced dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset)) with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"): - trainer.prepare_dataloader(dataloader, shuffle=True) + trainer._data_connector._prepare_dataloader(dataloader, shuffle=True) class LoaderTestModel(BoringModel): @@ -351,7 +351,7 @@ def test_error_raised_with_float_limited_eval_batches(): MisconfigurationException, match=fr"{limit_val_batches} \* {dl_size} < 1. 
Please increase the `limit_val_batches`", ): - trainer._reset_eval_dataloader(RunningStage.VALIDATING, model) + trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model) @pytest.mark.parametrize( @@ -375,4 +375,4 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl): model = BoringModel() trainer._data_connector.attach_data(model, val_dataloaders=val_dl) with pytest.warns(PossibleUserWarning, match="recommended .* turn this off for val/test/predict"): - trainer._reset_eval_dataloader(RunningStage.VALIDATING, model) + trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 47013c9e3f387..7de68b60cac8b 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -563,7 +563,7 @@ def train_dataloader(self): @RunIf(skip_windows=True) @pytest.mark.parametrize("ckpt_path", (None, "best", "specific")) @pytest.mark.parametrize("stage", ("train", "test", "val")) -@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4) +@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4) def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): """Test that error is raised if dataloader with only a few workers is used.""" @@ -593,7 +593,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): @RunIf(skip_windows=True) @pytest.mark.parametrize("ckpt_path", (None, "best", "specific")) @pytest.mark.parametrize("stage", ("train", "test", "val")) -@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4) +@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4) def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): """Test that error is raised if dataloader with only a few workers is used.""" diff --git 
a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 436a82c877c4f..4ba30dd7da1c0 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -350,7 +350,7 @@ def __init__(self, data_source, name) -> None: with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(use_fault_tolerant))}): trainer = Trainer(replace_sampler_ddp=replace_sampler_ddp, strategy="ddp", gpus=2) - dataloader = trainer.prepare_dataloader(dataloader, shuffle=True) + dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=True) _count = 0 _has_fastforward_sampler = False @@ -390,7 +390,7 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, t dataloader = CombinedLoader( {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)}, ) - dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) + dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False) assert len(dataloader) == 4 if replace_sampler_ddp else 8 for a_length in [6, 8, 10]: @@ -404,7 +404,7 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, t length = max(a_length, 8) assert len(dataloader) == length - dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) + dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False) assert len(dataloader) == length // 2 if replace_sampler_ddp else length if replace_sampler_ddp: last_batch = list(dataloader)[-1] @@ -429,6 +429,6 @@ def __iter__(self): ) assert get_len(dataloader) == float("inf") assert len(dataloader.loaders["b"].loader) == 8 - dataloader = trainer.prepare_dataloader(dataloader, shuffle=False) + dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False) assert len(dataloader.loaders["b"].loader) == 4 if replace_sampler_ddp else 8 assert get_len(dataloader) == float("inf") diff --git 
a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index e33bc91621a2b..e3e4f00c6e8d3 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -666,7 +666,7 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w return dataset -@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic") +@mock.patch("pytorch_lightning.trainer.connectors.data_connector._validate_fault_tolerant_automatic") @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(_, tmpdir, use_fault_tolerant): """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.""" @@ -895,7 +895,7 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_ return model.seen_batches, model.parameters() -@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic") +@mock.patch("pytorch_lightning.trainer.connectors.data_connector._validate_fault_tolerant_automatic") @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.parametrize( "dataset_classes",