
typing trainer.connectors.checkpoint_connector #9419


Merged: 17 commits, Sep 15, 2021
1 change: 1 addition & 0 deletions pyproject.toml
@@ -64,6 +64,7 @@ module = [
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.loops.optimization.*",
"pytorch_lightning.loops.evaluation_loop",
"pytorch_lightning.trainer.connectors.checkpoint_connector",
"pytorch_lightning.trainer.connectors.logger_connector.*",
"pytorch_lightning.trainer.progress",
"pytorch_lightning.tuner.auto_gpu_select",
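The pyproject.toml entry opts `pytorch_lightning.trainer.connectors.checkpoint_connector` into the project's per-module mypy checking, and the source changes below swap plain `str` path parameters for the `_PATH` alias imported from `pytorch_lightning.utilities.types`. As a minimal, hedged sketch (the real module may define it differently), the alias amounts to roughly:

```python
# Rough sketch of the _PATH alias assumed throughout the diffs below; the actual
# definition lives in pytorch_lightning/utilities/types.py and may differ in detail.
from pathlib import Path
from typing import Union

_PATH = Union[str, Path]  # a filesystem location given either as a string or a pathlib.Path
```

Widening the annotation documents that callers may pass either form, without changing runtime behaviour.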
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -28,7 +28,7 @@
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _NATIVE_AMP_AVAILABLE:
from torch.cuda.amp import GradScaler
@@ -393,7 +393,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
with self.training_type_plugin.model_sharded_context():
yield

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
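To see what the widened `save_checkpoint` signature buys callers, here is a toy stand-in (not the real Accelerator method, which delegates to the training type plugin) showing that both string and `Path` arguments now satisfy the annotation:

```python
from pathlib import Path
from typing import Any, Dict, Union

_PATH = Union[str, Path]  # assumed alias, see the sketch above

def save_checkpoint(checkpoint: Dict[str, Any], filepath: _PATH) -> None:
    """Toy stand-in mirroring the widened signature; it only reports what it would do."""
    print(f"would dump {len(checkpoint)} top-level keys to {filepath}")

save_checkpoint({"state_dict": {}}, "checkpoints/epoch=3.ckpt")         # plain string
save_checkpoint({"state_dict": {}}, Path("checkpoints") / "last.ckpt")  # pathlib.Path
```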
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -36,7 +36,7 @@
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.types import LRSchedulerTypeTuple
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

warning_cache = WarningCache()
@@ -664,7 +664,7 @@ def deepspeed_engine(self):
def _multi_device(self) -> bool:
return self.num_processes > 1 or self.num_nodes > 1

def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
def save_checkpoint(self, checkpoint: Dict, filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -20,6 +20,7 @@
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PATH

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
@@ -62,10 +63,10 @@ def pre_dispatch(self) -> None:
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()

def save(self, state_dict: Dict, path: str) -> None:
def save(self, state_dict: Dict, path: _PATH) -> None:
xm.save(state_dict, path)

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
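When a backend API expects a plain string, a common trick is to normalize the `_PATH` with `os.fspath` before handing it over. A hedged sketch of that pattern follows; it is illustrative only and not necessarily what the TPU plugins do, since they forward the path to `xm.save` directly:

```python
import os
from pathlib import Path
from typing import Any, Dict, Union

_PATH = Union[str, Path]

def save(state_dict: Dict[str, Any], path: _PATH) -> None:
    # os.fspath turns a Path into its string form and leaves strings untouched,
    # so downstream code that expects str keeps working.
    normalized = os.fspath(path)
    print(f"would save {len(state_dict)} entries to {normalized}")
```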
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -35,7 +35,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
@@ -207,7 +207,7 @@ def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", resul
self.mp_queue.put(results)
self.lightning_module.add_to_queue(self.mp_queue) # adds the `callback_metrics` to the queue

def save(self, state_dict: Dict, path: str) -> None:
def save(self, state_dict: Dict, path: _PATH) -> None:
xm.save(state_dict, path)

def broadcast(self, obj: object, src: int = 0) -> object:
@@ -303,7 +303,7 @@ def _pod_progress_bar_force_stdout(self) -> None:
if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
print()

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,6 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union

import torch
@@ -26,7 +25,7 @@
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

TBroadcast = TypeVar("T")

@@ -152,7 +151,7 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
"""
return self._results

def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)

@@ -259,7 +258,7 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
model = self.lightning_module
return model.state_dict()

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
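The base plugin's `load_checkpoint` and `save_checkpoint` hand the path through to a `CheckpointIO` object. A minimal, hypothetical I/O helper with the same `_PATH`-typed surface might look like the sketch below; the real `CheckpointIO`/`TorchCheckpointIO` plugins in `pytorch_lightning.plugins.io` have a richer interface:

```python
from pathlib import Path
from typing import Any, Dict, Union

import torch

_PATH = Union[str, Path]  # assumed alias as above

class ToyCheckpointIO:
    """Illustrative only: a bare-bones checkpoint I/O helper with _PATH-typed methods."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH) -> None:
        # torch.save accepts both strings and path-like objects.
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: _PATH) -> Dict[str, Any]:
        return torch.load(path, map_location="cpu")
```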
24 changes: 13 additions & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -14,29 +14,30 @@

import os
import re
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional

import torch
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.types import _PATH
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _OMEGACONF_AVAILABLE:
from omegaconf import Container


class CheckpointConnector:
def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] = None):
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
self.trainer = trainer
self.resume_checkpoint_path = resume_from_checkpoint
self._loaded_checkpoint = {}
self._loaded_checkpoint: Dict[str, Any] = {}

@property
def hpc_resume_path(self) -> Optional[str]:
@@ -77,7 +78,7 @@ def resume_end(self) -> None:
# wait for all to catch up
self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")

def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
state-restore, in this priority:

@@ -140,7 +141,7 @@ def restore_model(self) -> None:
if isinstance(module, Metric):
module.reset()

def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
@@ -192,6 +193,7 @@ def restore_loops(self) -> None:
# crash if max_epochs is lower than the current epoch from the checkpoint
if (
FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
and self.trainer.max_epochs is not None
and self.trainer.current_epoch > self.trainer.max_epochs
):
raise MisconfigurationException(
@@ -268,7 +270,7 @@ def restore_lr_schedulers(self) -> None:
# PRIVATE OPS
# ----------------------------------

def hpc_save(self, folderpath: str, logger):
def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
fs = get_filesystem(folderpath)
@@ -382,7 +384,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

return checkpoint

def hpc_load(self, checkpoint_path: str) -> None:
def hpc_load(self, checkpoint_path: _PATH) -> None:
"""Attempts to restore the full training and model state from a HPC checkpoint file.

.. deprecated:: v1.4 Will be removed in v1.6. Use :meth:`restore` instead.
@@ -393,7 +395,7 @@ def hpc_load(self, checkpoint_path: str) -> None:
)
self.restore(checkpoint_path)

def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = "ckpt_") -> Optional[int]:
def max_ckpt_version_in_folder(self, dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.

Args:
@@ -423,14 +425,14 @@ def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str =

return max(ckpt_vs)

def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
def get_max_ckpt_path_from_folder(self, folder_path: _PATH) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""

max_suffix = self.max_ckpt_version_in_folder(folder_path)
ckpt_number = max_suffix if max_suffix is not None else 0
return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"

def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
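Besides the signature changes, the `restore_loops` hunk adds an explicit `self.trainer.max_epochs is not None` guard. Because `max_epochs` is an `Optional[int]`, comparing it against `current_epoch` without narrowing does not pass strict type checking (and would raise at runtime if it were `None`). The guard follows the usual narrowing pattern, roughly:

```python
from typing import Optional

def exceeds_max_epochs(current_epoch: int, max_epochs: Optional[int]) -> bool:
    # The explicit None check narrows Optional[int] to int before the comparison,
    # which is what makes the expression acceptable to mypy.
    return max_epochs is not None and current_epoch > max_epochs
```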
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/properties.py
@@ -181,7 +181,7 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
self.accelerator.optimizers = new_optims

@property
def lr_schedulers(self) -> Optional[list]:
def lr_schedulers(self) -> List:
return self.accelerator.lr_schedulers

@lr_schedulers.setter
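The `lr_schedulers` return annotation changes from `Optional[list]` to `List`, documenting that the accelerator always exposes a (possibly empty) list rather than `None`, so call sites need no `None` handling. A toy illustration of the pattern, with made-up classes:

```python
from typing import Any, List

class ToyAccelerator:
    def __init__(self) -> None:
        self.lr_schedulers: List[Any] = []  # always a list, never None

class ToyTrainer:
    def __init__(self) -> None:
        self.accelerator = ToyAccelerator()

    @property
    def lr_schedulers(self) -> List[Any]:
        # A concrete (possibly empty) list spares callers the Optional check.
        return self.accelerator.lr_schedulers

trainer = ToyTrainer()
assert trainer.lr_schedulers == []
```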