Skip to content

[WIP] Fix mypy typing for utilities.auto_restart #8783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ module = [
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",
"pytorch_lightning.utilities.debugging",
Expand Down
35 changes: 21 additions & 14 deletions pytorch_lightning/utilities/auto_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from collections.abc import Mapping
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union

from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
Expand All @@ -34,7 +34,7 @@ class FastForwardSampler(Sampler):
samples seen in the last iterations (for the current worker).
"""

def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
def __init__(self, sampler: Union[Sampler, Generator, Iterator], attr_name: Optional[str] = None) -> None:
super().__init__(data_source=None)
self._sampler = sampler
self.restarting: bool = False
Expand Down Expand Up @@ -116,12 +116,13 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None

return current_iteration

def _load_cached_state(self):
def _load_cached_state(self) -> None:
if self._cached_state_dict is None or self.worker_id not in self._cached_state_dict:
return
self._current_iteration = self._cached_state_dict[self.worker_id]["current_iteration"]
# delete cached state, prevent reloading every time iter() is called
self._cached_state_dict = None
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure this is not necessary :/

Copy link
Contributor

@awaelchli awaelchli Aug 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think we should put warn_no_return = False for mypy? @stancld @carmocca

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember reading a PEP recommendation (PEP 8) suggesting that return statements in a function should be consistent: if any return statement returns a value, then every return in that function should be explicit, including a bare `return None` at the end. But yes, my suggestion might be redundant here.



class CaptureIterableDataset(IterableDataset):
Expand All @@ -143,8 +144,10 @@ def __init__(self, dataset: IterableDataset) -> None:
def sampler(self) -> Sampler:
return self.dataset.sampler

def state_dict(self) -> Dict[str, Any]:
return {k: v.state_dict() for k, v in self.samplers.items()}
def state_dict(self) -> Optional[Dict[str, Any]]:
if isinstance(self.samplers, dict):
return {k: v.state_dict() for k, v in self.samplers.items()}
return {}

def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
self._state_dict = deepcopy(state_dict)
Expand Down Expand Up @@ -232,7 +235,7 @@ def store_samplers_state_dict(iterator: Iterator, sampler_state_dict: List) -> N
iterator._sampler_state_dict_cache = sampler_state_dict

@staticmethod
def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List):
def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List) -> Any:
"""
This function is used to remove the sampler state dict from provided data batch.
The custom data has this format:
Expand All @@ -255,7 +258,7 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List):
will extract the current iteration as part of the metadata returned by a custom batch.
"""

def _sanitize(data: Mapping):
def _sanitize(data: Mapping) -> List[Tuple[Any, Any]]:
out = []
for k, v in data.items():
if k == AutoRestartBatchKeys.PL_SAMPLERS:
Expand All @@ -267,11 +270,11 @@ def _sanitize(data: Mapping):
return apply_to_collection(data, Mapping, _sanitize)

@staticmethod
def extract_samplers_state_dict_from_batch(batch) -> List[Dict[int, Any]]:
def extract_samplers_state_dict_from_batch(batch: Any) -> Tuple[Any, List[Dict[int, Any]]]:
"""
This function is used to convert a batch into a state_dict
"""
samplers_state_dict = []
samplers_state_dict: List[Dict[int, Any]] = []

batch = CaptureIterableDataset._sanitize_batch_from_sampler_state(batch, samplers_state_dict)

Expand All @@ -287,6 +290,7 @@ def _find_fast_forward_samplers(dataloader: DataLoader) -> Optional[FastForwardS

if isinstance(dataloader.batch_sampler, FastForwardSampler):
return dataloader.batch_sampler
return None


def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> Iterator:
Expand Down Expand Up @@ -328,22 +332,25 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str

def _dataloader_to_state_dict(
dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None
) -> List[Dict[str, Any]]:
) -> Dict[Union[int, str], Union[Dict[str, int], Optional[int]]]:
"""
Convert a dataloader to its associated state dict
"""
out = {}
out: Dict[Union[int, str], Union[Dict[str, int], Optional[int]]] = {}
if iterator is not None:
out.update(_find_current_worker(iterator))
for iterator_k, iterator_v in _find_current_worker(iterator).items():
out[iterator_k] = iterator_v

if not isinstance(dataloader.dataset, CaptureIterableDataset):
fast_forward_sampler = _find_fast_forward_samplers(dataloader)
if fast_forward_sampler is not None:
out.update(fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed))
samplers_state_dict = fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed)
for sampler_k, sampler_v in samplers_state_dict.items():
out[sampler_k] = sampler_v
return out


def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> DataLoader:
def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> DataLoader:
"""
Reload ``DataLoader`` fast-forward sampler state dict.
"""
Expand Down