From 056155145f2bc3a1001f7cdeaf990a15dc7d06f9 Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Fri, 6 Aug 2021 21:53:54 +0200
Subject: [PATCH 01/10] Fix some mypy issues

---
 pyproject.toml                              |  1 +
 pytorch_lightning/utilities/auto_restart.py | 26 ++++++++++++---------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0da4b2cfe10b4..a2e7deb38ecd7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,6 +63,7 @@ module = [
     "pytorch_lightning.trainer.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector",
     "pytorch_lightning.utilities.argparse",
+    "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.cli",
     "pytorch_lightning.utilities.cloud_io",
     "pytorch_lightning.utilities.debugging",
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 464823038ae2e..f61c789df5a97 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -14,7 +14,7 @@
 from collections.abc import Mapping
 from copy import deepcopy
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
 
 from torch.utils.data import Dataset, get_worker_info, Sampler
 from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
@@ -34,7 +34,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """
 
-    def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Union[Sampler, Generator, Iterator], attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
@@ -116,12 +116,13 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None
 
         return current_iteration
 
-    def _load_cached_state(self):
+    def _load_cached_state(self) -> None:
         if self._cached_state_dict is None or self.worker_id not in self._cached_state_dict:
             return
         self._current_iteration = self._cached_state_dict[self.worker_id]["current_iteration"]
         # delete cached state, prevent reloading every time iter() is called
         self._cached_state_dict = None
+        return
 
 
 class CaptureIterableDataset(IterableDataset):
@@ -143,8 +144,10 @@ def __init__(self, dataset: IterableDataset) -> None:
     def sampler(self) -> Sampler:
         return self.dataset.sampler
 
-    def state_dict(self) -> Dict[str, Any]:
-        return {k: v.state_dict() for k, v in self.samplers.items()}
+    def state_dict(self) -> Optional[Dict[str, Any]]:
+        if isinstance(self.samplers, dict):
+            return {k: v.state_dict() for k, v in self.samplers.items()}
+        return None
 
     def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
         self._state_dict = deepcopy(state_dict)
@@ -232,7 +235,7 @@ def store_samplers_state_dict(iterator: Iterator, sampler_state_dict: List) -> N
         iterator._sampler_state_dict_cache = sampler_state_dict
 
     @staticmethod
-    def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List):
+    def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List) -> Any:
         """
         This function is used to remove the sampler state dict from provided data batch.
         The custom data has this format:
@@ -255,7 +258,7 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List):
         will extract the current iteration as part of the metadata returned by a custom batch.
         """
 
-        def _sanitize(data: Mapping):
+        def _sanitize(data: Mapping) -> List[Tuple[Any, Any]]:
             out = []
             for k, v in data.items():
                 if k == AutoRestartBatchKeys.PL_SAMPLERS:
@@ -267,11 +270,11 @@ def _sanitize(data: Mapping):
         return apply_to_collection(data, Mapping, _sanitize)
 
     @staticmethod
-    def extract_samplers_state_dict_from_batch(batch) -> List[Dict[int, Any]]:
+    def extract_samplers_state_dict_from_batch(batch: Any) -> Tuple[Any, List[Dict[int, Any]]]:
         """
         This function is used to convert a batch into a state_dict
         """
-        samplers_state_dict = []
+        samplers_state_dict: List[Dict[int, Any]] = []
 
         batch = CaptureIterableDataset._sanitize_batch_from_sampler_state(batch, samplers_state_dict)
 
@@ -287,6 +290,7 @@ def _find_fast_forward_samplers(dataloader: DataLoader) -> Optional[FastForwardS
 
     if isinstance(dataloader.batch_sampler, FastForwardSampler):
         return dataloader.batch_sampler
+    return None
 
 
 def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> Iterator:
@@ -328,7 +332,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
 
 def _dataloader_to_state_dict(
     dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None
-) -> List[Dict[str, Any]]:
+) -> Dict[str, Any]:
     """
     Convert a dataloader to its associated state dict
     """
@@ -343,7 +347,7 @@ def _dataloader_to_state_dict(
     return out
 
 
-def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> DataLoader:
+def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> DataLoader:
     """
     Reload ``DataLoader`` fast-forward sampler state dict.
     """

From 8cdf45f493bdb6234ef8aa30163676dc670556c9 Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Fri, 6 Aug 2021 22:27:07 +0200
Subject: [PATCH 02/10] Replace Dict with Mapping to enable passing into dict.update

---
 pytorch_lightning/utilities/auto_restart.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index f61c789df5a97..1fc55f5953ee0 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -85,7 +85,7 @@ def __iter__(self) -> Iterator[Any]:
     def __len__(self) -> int:
         return len(self._sampler)
 
-    def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
+    def state_dict(self, num_batches_processed: Optional[int] = None) -> Mapping[int, Dict[str, int]]:
         """Returns the state of the sampler in the current worker.
         The worker id indexes the state dict."""
         return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}}

From fe19c79cb8f470141d61b1ef7b76263ab69722fd Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Fri, 6 Aug 2021 22:39:38 +0200
Subject: [PATCH 03/10] Revert the last change

---
 pytorch_lightning/utilities/auto_restart.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 1fc55f5953ee0..f61c789df5a97 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -85,7 +85,7 @@ def __iter__(self) -> Iterator[Any]:
     def __len__(self) -> int:
         return len(self._sampler)
 
-    def state_dict(self, num_batches_processed: Optional[int] = None) -> Mapping[int, Dict[str, int]]:
+    def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
         """Returns the state of the sampler in the current worker.
         The worker id indexes the state dict."""
         return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}}

From 475cf66db2a3d9d60c5c280862ead5ccc2162da4 Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Fri, 6 Aug 2021 22:46:55 +0200
Subject: [PATCH 04/10] Rewrite dict's update method with for loop to fix typing

---
 pytorch_lightning/utilities/auto_restart.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index f61c789df5a97..c822c841f9ad7 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -343,7 +343,8 @@ def _dataloader_to_state_dict(
     if not isinstance(dataloader.dataset, CaptureIterableDataset):
         fast_forward_sampler = _find_fast_forward_samplers(dataloader)
         if fast_forward_sampler is not None:
-            out.update(fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed))
+            for k, v in fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).item():
+                out[k] = v
 
     return out

From 0e2e120fc58950c2885fc96198891c4364232625 Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Fri, 6 Aug 2021 22:48:41 +0200
Subject: [PATCH 05/10] Fix a typo: item() -> items()

---
 pytorch_lightning/utilities/auto_restart.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index c822c841f9ad7..f7a08fd4d8ac4 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -343,7 +343,7 @@ def _dataloader_to_state_dict(
     if not isinstance(dataloader.dataset, CaptureIterableDataset):
         fast_forward_sampler = _find_fast_forward_samplers(dataloader)
         if fast_forward_sampler is not None:
-            for k, v in fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).item():
+            for k, v in fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).items():
                 out[k] = v
 
     return out

From 303c20baca31d2dc97c1869de2c11391f5238e0c Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Fri, 6 Aug 2021 23:00:17 +0200
Subject: [PATCH 06/10] Substantially rewrite typing for _dataloader_to_state_dict

---
 pytorch_lightning/utilities/auto_restart.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index f7a08fd4d8ac4..303b89e32bac7 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -332,19 +332,22 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
 
 def _dataloader_to_state_dict(
     dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None
-) -> Dict[str, Any]:
+) -> Dict[Union[int, str], Union[Dict[str, int], Optional[int]]]:
     """
     Convert a dataloader to its associated state dict
     """
-    out = {}
+    out: Dict[Union[int, str], Union[Dict[str, int], Optional[int]]] = {}
     if iterator is not None:
-        out.update(_find_current_worker(iterator))
+        for iterator_k, iterator_v in _find_current_worker(iterator).items():
+            out[iterator_k] = iterator_v
 
     if not isinstance(dataloader.dataset, CaptureIterableDataset):
         fast_forward_sampler = _find_fast_forward_samplers(dataloader)
         if fast_forward_sampler is not None:
-            for k, v in fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).items():
-                out[k] = v
+            for sampler_k, sampler_v in (
+                fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).items()
+            ):
+                out[sampler_k] = sampler_v
 
     return out

From db28124880614fc9282ae5a4418f23ba941ce4f9 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 6 Aug 2021 21:01:24 +0000
Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/utilities/auto_restart.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 303b89e32bac7..11af98a650910 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -344,9 +344,9 @@ def _dataloader_to_state_dict(
     if not isinstance(dataloader.dataset, CaptureIterableDataset):
         fast_forward_sampler = _find_fast_forward_samplers(dataloader)
         if fast_forward_sampler is not None:
-            for sampler_k, sampler_v in (
-                fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).items()
-            ):
+            for sampler_k, sampler_v in fast_forward_sampler.state_dict(
+                num_batches_processed=num_batches_processed
+            ).items():
                 out[sampler_k] = sampler_v
 
     return out

From c5f6b836ac738055141648a1a74c3d9a7eea3f6e Mon Sep 17 00:00:00 2001
From: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Date: Mon, 9 Aug 2021 20:23:36 +0200
Subject: [PATCH 08/10] Apply the tchaton's suggestion

Co-authored-by: thomas chaton
---
 pytorch_lightning/utilities/auto_restart.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 11af98a650910..6315844ff64f8 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -344,9 +344,8 @@ def _dataloader_to_state_dict(
     if not isinstance(dataloader.dataset, CaptureIterableDataset):
         fast_forward_sampler = _find_fast_forward_samplers(dataloader)
         if fast_forward_sampler is not None:
-            for sampler_k, sampler_v in fast_forward_sampler.state_dict(
-                num_batches_processed=num_batches_processed
-            ).items():
+            samplers_state_dict = fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).items()
+            for sampler_k, sampler_v in samplers_state_dict.items():
                 out[sampler_k] = sampler_v
 
     return out

From d35be356e6bebf6f4bd38d354ef130d32734610d Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Mon, 9 Aug 2021 20:24:25 +0200
Subject: [PATCH 09/10] Apply the carmocca's suggestion

---
 pytorch_lightning/utilities/auto_restart.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 6315844ff64f8..46971699c3055 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -147,7 +147,7 @@ def sampler(self) -> Sampler:
     def state_dict(self) -> Optional[Dict[str, Any]]:
         if isinstance(self.samplers, dict):
             return {k: v.state_dict() for k, v in self.samplers.items()}
-        return None
+        return {}
 
     def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
         self._state_dict = deepcopy(state_dict)

From 60f12e3a8cc1939377049ed6ea803a00fd0d3173 Mon Sep 17 00:00:00 2001
From: "Stancl, Daniel"
Date: Mon, 9 Aug 2021 22:23:28 +0200
Subject: [PATCH 10/10] Remove unintentional dict's items() call

---
 pytorch_lightning/utilities/auto_restart.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 46971699c3055..4bb9e1bc4c175 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -344,7 +344,7 @@ def _dataloader_to_state_dict(
     if not isinstance(dataloader.dataset, CaptureIterableDataset):
         fast_forward_sampler = _find_fast_forward_samplers(dataloader)
         if fast_forward_sampler is not None:
-            samplers_state_dict = fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed).items()
+            samplers_state_dict = fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed)
             for sampler_k, sampler_v in samplers_state_dict.items():
                 out[sampler_k] = sampler_v
 
     return out
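
A minimal, self-contained sketch (not part of the patch series above) of the kind of key-type mismatch that mypy appears to flag for the original `out.update(...)` calls and that the later commits work around with per-item assignment; the names `out` and `worker_state` are illustrative only, and the value type is simplified to `int`:

from typing import Dict, Union

# Mirrors the shape of the accumulator in `_dataloader_to_state_dict`:
# keys may be worker ids (int) or metadata keys (str).
out: Dict[Union[int, str], int] = {}
worker_state: Dict[int, int] = {0: 5}

# out.update(worker_state)
# mypy rejects the call above: a `Dict[int, int]` is not accepted where a
# mapping keyed by `Union[int, str]` is expected, because mapping key types
# are not covariant.

# Assigning item by item type-checks, since each `int` key is individually
# assignable to `Union[int, str]` -- the pattern the final patch settles on.
for k, v in worker_state.items():
    out[k] = v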