diff --git a/pyproject.toml b/pyproject.toml
index d07f19ef10986..068b929f00d76 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,6 +63,7 @@ module = [
     "pytorch_lightning.trainer.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector",
     "pytorch_lightning.utilities.argparse",
+    "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.cli",
     "pytorch_lightning.utilities.cloud_io",
     "pytorch_lightning.utilities.debugging",
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 464823038ae2e..4bb9e1bc4c175 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -14,7 +14,7 @@
 from collections.abc import Mapping
 from copy import deepcopy
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
 
 from torch.utils.data import Dataset, get_worker_info, Sampler
 from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
 
@@ -34,7 +34,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """
 
-    def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Union[Sampler, Generator, Iterator], attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
@@ -116,12 +116,13 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None
 
         return current_iteration
 
-    def _load_cached_state(self):
+    def _load_cached_state(self) -> None:
         if self._cached_state_dict is None or self.worker_id not in self._cached_state_dict:
             return
         self._current_iteration = self._cached_state_dict[self.worker_id]["current_iteration"]
         # delete cached state, prevent reloading every time iter() is called
         self._cached_state_dict = None
+        return
 
 
 class CaptureIterableDataset(IterableDataset):
@@ -143,8 +144,10 @@ def __init__(self, dataset: IterableDataset) -> None:
     def sampler(self) -> Sampler:
         return self.dataset.sampler
 
-    def state_dict(self) -> Dict[str, Any]:
-        return {k: v.state_dict() for k, v in self.samplers.items()}
+    def state_dict(self) -> Optional[Dict[str, Any]]:
+        if isinstance(self.samplers, dict):
+            return {k: v.state_dict() for k, v in self.samplers.items()}
+        return {}
 
     def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
         self._state_dict = deepcopy(state_dict)
@@ -232,7 +235,7 @@ def store_samplers_state_dict(iterator: Iterator, sampler_state_dict: List) -> N
         iterator._sampler_state_dict_cache = sampler_state_dict
 
     @staticmethod
-    def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List):
+    def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List) -> Any:
         """
         This function is used to remove the sampler state dict from provided data batch.
         The custom data has this format:
@@ -255,7 +258,7 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List):
         will extract the current iteration as part of the metadata returned by a custom batch.
         """
 
-        def _sanitize(data: Mapping):
+        def _sanitize(data: Mapping) -> List[Tuple[Any, Any]]:
             out = []
             for k, v in data.items():
                 if k == AutoRestartBatchKeys.PL_SAMPLERS:
@@ -267,11 +270,11 @@ def _sanitize(data: Mapping):
         return apply_to_collection(data, Mapping, _sanitize)
 
     @staticmethod
-    def extract_samplers_state_dict_from_batch(batch) -> List[Dict[int, Any]]:
+    def extract_samplers_state_dict_from_batch(batch: Any) -> Tuple[Any, List[Dict[int, Any]]]:
         """
         This function is used to convert a batch into a state_dict
         """
-        samplers_state_dict = []
+        samplers_state_dict: List[Dict[int, Any]] = []
 
         batch = CaptureIterableDataset._sanitize_batch_from_sampler_state(batch, samplers_state_dict)
 
@@ -287,6 +290,7 @@ def _find_fast_forward_samplers(dataloader: DataLoader) -> Optional[FastForwardS
 
     if isinstance(dataloader.batch_sampler, FastForwardSampler):
         return dataloader.batch_sampler
+    return None
 
 
 def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> Iterator:
@@ -328,22 +332,25 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
 def _dataloader_to_state_dict(
     dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None
-) -> List[Dict[str, Any]]:
+) -> Dict[Union[int, str], Union[Dict[str, int], Optional[int]]]:
     """
     Convert a dataloader to its associated state dict
     """
-    out = {}
+    out: Dict[Union[int, str], Union[Dict[str, int], Optional[int]]] = {}
     if iterator is not None:
-        out.update(_find_current_worker(iterator))
+        for iterator_k, iterator_v in _find_current_worker(iterator).items():
+            out[iterator_k] = iterator_v
 
     if not isinstance(dataloader.dataset, CaptureIterableDataset):
         fast_forward_sampler = _find_fast_forward_samplers(dataloader)
         if fast_forward_sampler is not None:
-            out.update(fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed))
+            samplers_state_dict = fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed)
+            for sampler_k, sampler_v in samplers_state_dict.items():
+                out[sampler_k] = sampler_v
 
     return out
 
 
-def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> DataLoader:
+def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> DataLoader:
     """
     Reload ``DataLoader`` fast-forward sampler state dict.
     """