From a93e452d1604967974382c07c11c57380bd11331 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Mon, 19 Jul 2021 16:51:47 +0100
Subject: [PATCH 01/37] poc API

---
 .../plugins/checkpoint/__init__.py            |  0
 .../plugins/checkpoint/checkpoint.py          | 38 +++++++++++++++++++
 pytorch_lightning/plugins/checkpoint/torch.py | 33 ++++++++++++++++
 3 files changed, 71 insertions(+)
 create mode 100644 pytorch_lightning/plugins/checkpoint/__init__.py
 create mode 100644 pytorch_lightning/plugins/checkpoint/checkpoint.py
 create mode 100644 pytorch_lightning/plugins/checkpoint/torch.py

diff --git a/pytorch_lightning/plugins/checkpoint/__init__.py b/pytorch_lightning/plugins/checkpoint/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py
new file mode 100644
index 0000000000000..f4f18527c8afd
--- /dev/null
+++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py
@@ -0,0 +1,38 @@
+from abc import ABC
+from pathlib import Path
+from typing import Any, Dict, Mapping, Union
+
+
+class CheckpointPlugin(ABC):
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            checkpoint: dict containing model and trainer state
+            filepath: write-target file's path
+        """
+
+    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        """
+        Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
+        Args:
+            checkpoint_path: Path to checkpoint
+
+        Returns: The loaded checkpoint.
+        """
+        pass
+
+    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        """
+        Given the loaded checkpoint file, loads the state dict into the model.
+        Args:
+            checkpoint: The loaded checkpoint file.
+        """
+
+    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        """
+        Given the loaded checkpoint file, loads the optimizer state dicts into optimizers.
+        Args:
+            checkpoint: The loaded checkpoint file.
+        """
diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py
new file mode 100644
index 0000000000000..d8f38b786f416
--- /dev/null
+++ b/pytorch_lightning/plugins/checkpoint/torch.py
@@ -0,0 +1,33 @@
+from abc import ABC
+from pathlib import Path
+from typing import Any, Dict, Mapping, Union
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+
+
+class CheckpointPlugin(ABC):
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        # dump states as a checkpoint dictionary object
+        try:
+            # write the checkpoint dictionary on the file
+            atomic_save(checkpoint, filepath)
+        except AttributeError as err:
+            key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
+            checkpoint.pop(key, None)
+            rank_zero_warn(f'Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}')
+            atomic_save(checkpoint, filepath)
+
+    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage))
+
+    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        self.lightning_module.load_state_dict(checkpoint["state_dict"])
+
+    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        optimizer_states = checkpoint["optimizer_states"]
+        for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states):
+            optimizer.load_state_dict(opt_state)

From e7d2b66824d5a2763b42fc08c3a31e660df95616 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 5 Aug 2021 11:19:55 +0100
Subject: [PATCH 02/37] Fix up the API, unsure on connection

---
 .../plugins/checkpoint/checkpoint.py          |  23 ++++
 .../plugins/checkpoint/deepspeed.py           | 120 ++++++++++++++++++
 pytorch_lightning/plugins/checkpoint/torch.py |  11 +-
 .../plugins/training_type/deepspeed.py        | 118 +++--------------
 .../plugins/training_type/fully_sharded.py    |   6 -
 .../training_type/training_type_plugin.py     |  71 +++++------
 6 files changed, 204 insertions(+), 145 deletions(-)
 create mode 100644 pytorch_lightning/plugins/checkpoint/deepspeed.py

diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py
index f4f18527c8afd..aa71cf6a001b4 100644
--- a/pytorch_lightning/plugins/checkpoint/checkpoint.py
+++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py
@@ -2,9 +2,32 @@
 from pathlib import Path
 from typing import Any, Dict, Mapping, Union
 
+from torch.nn import Module
+
+from pytorch_lightning import LightningModule
+
 
 class CheckpointPlugin(ABC):
+    def __init__(self):
+        self._training_type_plugin = None
+
+    @property
+    def training_type_plugin(self) -> 'TrainingTypePlugin':
+        return self._training_type_plugin
+
+    @training_type_plugin.setter
+    def training_type_plugin(self, plugin) -> None:
+        self._training_type_plugin = plugin
+
+    @property
+    def lightning_module(self) -> LightningModule:
+        return self.training_type_plugin.lightning_module
+
+    @property
+    def model(self) -> Module:
+        return self.training_type_plugin.model
+
 
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
diff --git a/pytorch_lightning/plugins/checkpoint/deepspeed.py b/pytorch_lightning/plugins/checkpoint/deepspeed.py
new file mode 100644
index 0000000000000..dfa0486301765
--- /dev/null
+++ b/pytorch_lightning/plugins/checkpoint/deepspeed.py
@@ -0,0 +1,120 @@
+from pathlib import Path
+from typing import Any, Dict, Mapping, Union, Optional
+
+import torch
+from torch import Tensor
+
+from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.warnings import WarningCache
+
+warning_cache = WarningCache()
+if _DEEPSPEED_AVAILABLE:
+    import deepspeed
+
+
+class DeepSpeedCheckpointPlugin(CheckpointPlugin):
+
+    def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+ + Args: + checkpoint: The checkpoint state dictionary + filepath: write-target file's path + """ + if self.training_type_plugin.zero_stage_3 and self.training_type_plugin._multi_device and self.training_type_plugin.is_global_zero: + warning_cache.warn( + "When saving the DeepSpeed Stage 3 checkpoint, " + "each worker will save a shard of the checkpoint within a directory. " + "If a single file is required after training, " + "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#" + "deepspeed-zero-stage-3-single-file for instructions." + ) + # Use deepspeed's internal checkpointing function to handle partitioned weights across processes + # dump states as a checkpoint dictionary object + _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] + checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} + self.model.save_checkpoint(filepath, client_state=checkpoint) + + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: + if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3: + # Broadcast to ensure we load from the rank 0 checkpoint + # This doesn't have to be the case when using deepspeed sharded checkpointing + checkpoint_path = self.training_type_plugin.broadcast(checkpoint_path) + return super().load_checkpoint_file(checkpoint_path) + + # Rely on deepspeed to load the checkpoint and necessary information + from pytorch_lightning.trainer.states import TrainerFn + + is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING + _, client_state = self.model.load_checkpoint( + checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting + ) + if client_state is None: + raise MisconfigurationException( + "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " + "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`." + ) + return client_state + + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` + if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3: + self.training_type_plugin.model_to_device() + self._restore_zero_state(checkpoint) + + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` + pass + + def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + """Returns model state.""" + model = self.lightning_module + return model.state_dict() + + def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: + """ + Overrides the normal load_state_dict behaviour in PyTorch to ensure + we gather parameters that may be sharded across processes before loading + the state dictionary when using ZeRO stage 3. + This is then automatically synced across processes. + + Args: + ckpt: The ckpt file. 
+ """ + + def load(module: torch.nn.Module, prefix=""): + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + state_dict = ckpt["state_dict"] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if self.training_type_plugin.is_global_zero: + module._load_from_state_dict( + state_dict=state_dict, + prefix=prefix, + local_metadata=local_metadata, + strict=True, + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + error_msgs=error_msgs, + ) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self.lightning_module, prefix="") diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index d8f38b786f416..b1d4af686a211 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -1,14 +1,16 @@ -from abc import ABC from pathlib import Path from typing import Any, Dict, Mapping, Union +from torch import Tensor + import pytorch_lightning as pl +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load -class CheckpointPlugin(ABC): +class TorchCheckpointPlugin(CheckpointPlugin): def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: # dump states as a checkpoint dictionary object @@ -31,3 +33,8 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: optimizer_states = checkpoint["optimizer_states"] for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) + + def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + """Returns model state.""" + model = self.lightning_module + return model.state_dict() diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 515be43caad31..b96a53bee60b7 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,6 +26,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.checkpoint.deepspeed import DeepSpeedCheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config @@ -113,6 +114,7 @@ def __init__( num_nodes: Optional[int] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_plugin: Optional[DeepSpeedCheckpointPlugin] = None, loss_scale: float = 0, 
initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -276,6 +278,10 @@ def __init__( super().__init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment ) + + self.checkpoint_plugin = DeepSpeedPlugin() if checkpoint_plugin is None else checkpoint_plugin + self.checkpoint_plugin.training_type_plugin = self + self.config = self._load_config(config) if self.config is None: # User has not overridden config, set defaults @@ -668,51 +674,9 @@ def deepspeed_engine(self): return self.model @property - def _multi_device(self) -> bool: + def multi_device(self) -> bool: return self.num_processes > 1 or self.num_nodes > 1 - def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: - """Save model/training states as a checkpoint file through state-dump and file-write. - - Args: - checkpoint: The checkpoint state dictionary - filepath: write-target file's path - """ - if self.zero_stage_3 and self._multi_device and self.is_global_zero: - warning_cache.warn( - "When saving the DeepSpeed Stage 3 checkpoint, " - "each worker will save a shard of the checkpoint within a directory. " - "If a single file is required after training, " - "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#" - "deepspeed-zero-stage-3-single-file for instructions." - ) - # Use deepspeed's internal checkpointing function to handle partitioned weights across processes - # dump states as a checkpoint dictionary object - _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] - checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} - self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) - - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: - if self.load_full_weights and self.zero_stage_3: - # Broadcast to ensure we load from the rank 0 checkpoint - # This doesn't have to be the case when using deepspeed sharded checkpointing - checkpoint_path = self.broadcast(checkpoint_path) - return super().load_checkpoint_file(checkpoint_path) - - # Rely on deepspeed to load the checkpoint and necessary information - from pytorch_lightning.trainer.states import TrainerFn - - is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING - _, client_state = self.deepspeed_engine.load_checkpoint( - checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting - ) - if client_state is None: - raise MisconfigurationException( - "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " - "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`." - ) - return client_state - @property def lightning_restore_optimizer_and_schedulers(self) -> bool: # managed by DeepSpeed @@ -724,62 +688,6 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool: ) return False - def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` - if self.load_full_weights and self.zero_stage_3: - self.model_to_device() - self._restore_zero_state(checkpoint) - - def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: - """ - Overrides the normal load_state_dict behaviour in PyTorch to ensure - we gather parameters that may be sharded across processes before loading - the state dictionary when using ZeRO stage 3. 
- This is then automatically synced across processes. - - Args: - ckpt: The ckpt file. - """ - - def load(module: torch.nn.Module, prefix=""): - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - state_dict = ckpt["state_dict"] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - # because zero3 puts placeholders in model params, this context - # manager gathers (unpartitions) the params of the current layer, then loads from - # the state dict and then re-partitions them again - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): - if self.is_global_zero: - module._load_from_state_dict( - state_dict=state_dict, - prefix=prefix, - local_metadata=local_metadata, - strict=True, - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - error_msgs=error_msgs, - ) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(self.lightning_module, prefix="") - - def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` - pass - def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: if self._original_accumulate_grad_batches is None: return super().update_global_step(total_batch_idx, current_global_step) @@ -818,3 +726,15 @@ def register_plugins(cls, plugin_registry: Dict) -> None: offload_params_device="nvme", offload_optimizer_device="nvme", ) + + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + self.checkpoint_plugin.load_model_state_dict(checkpoint) + + def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: + return self.checkpoint_plugin.save_checkpoint(checkpoint) + + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: + return self.checkpoint_plugin.load_checkpoint_file(checkpoint_path) + + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + return self.checkpoint_plugin.load_optimizer_state_dict(checkpoint) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 1acac25e96db4..bbf6143ce6b66 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -169,12 +169,6 @@ def model_to_device(self) -> None: # ensure we update the device type in the lightning module self.lightning_module.to(self.root_device) - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: - # Currently it is same as default TrainingTypePlugin, i.e. return - # the full state dict for FSDP, in the future, we will provide sharded - # state dict. 
- return super().lightning_module_state_dict() - @property def setup_optimizers_in_pre_dispatch(self) -> bool: # Setup optimizers after the Fully Sharded Model has been made diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index cdff37fd9bcb2..cff3c55063bde 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -25,9 +25,8 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -42,6 +41,15 @@ def __init__(self) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None self._call_configure_sharded_model_hook = True + self._checkpoint_plugin = TorchCheckpointPlugin() + + @property + def checkpoint_plugin(self) -> CheckpointPlugin: + return self._checkpoint_plugin + + @checkpoint_plugin.setter + def checkpoint_plugin(self, plugin) -> None: + self._checkpoint_plugin = plugin def connect(self, model: Module) -> None: """Called by the accelerator to connect the accelerator and the model with this plugin""" @@ -145,17 +153,6 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: """ return self._results - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage)) - - def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - self.lightning_module.load_state_dict(checkpoint["state_dict"]) - - def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - optimizer_states = checkpoint["optimizer_states"] - for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): - optimizer.load_state_dict(opt_state) - def start_training(self, trainer: "pl.Trainer") -> None: # double dispatch to initiate the training loop self._results = trainer.run_stage() @@ -268,30 +265,6 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> """ return current_global_step + 1 - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: - """Returns model state.""" - model = self.lightning_module - return model.state_dict() - - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: - """Save model/training states as a checkpoint file through state-dump and file-write. - - Args: - checkpoint: dict containing model and trainer state - filepath: write-target file's path - """ - # dump states as a checkpoint dictionary object - checkpoint = self.on_save(checkpoint) - if self.is_global_zero: - try: - # write the checkpoint dictionary on the file - atomic_save(checkpoint, filepath) - except AttributeError as err: - key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY - checkpoint.pop(key, None) - rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. 
An attribute is not picklable: {err}") - atomic_save(checkpoint, filepath) - @contextlib.contextmanager def model_sharded_context(self) -> Generator: """ @@ -370,3 +343,25 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) Called in the training loop before anything happens for that batch. """ pass + + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + return self.checkpoint_plugin.load_checkpoint_file(checkpoint_path) + + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + self.checkpoint_plugin.load_model_state_dict(checkpoint) + + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + self.checkpoint_plugin.load_optimizer_state_dict(checkpoint) + + def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + """Returns model state.""" + return self.checkpoint_plugin.lightning_module_state_dict() + + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ + self.checkpoint_plugin.save_checkpoint(checkpoint, filepath) From b41e7949a4b4fd3f2ab41a1a0e0706a530a03f4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Aug 2021 10:23:31 +0000 Subject: [PATCH 03/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 3 +-- pytorch_lightning/plugins/checkpoint/deepspeed.py | 9 ++++++--- pytorch_lightning/plugins/checkpoint/torch.py | 3 +-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index aa71cf6a001b4..33d2a6cb9bf2b 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -8,12 +8,11 @@ class CheckpointPlugin(ABC): - def __init__(self): self._training_type_plugin = None @property - def training_type_plugin(self) -> 'TrainingTypePlugin': + def training_type_plugin(self) -> "TrainingTypePlugin": return self._training_type_plugin @training_type_plugin.setter diff --git a/pytorch_lightning/plugins/checkpoint/deepspeed.py b/pytorch_lightning/plugins/checkpoint/deepspeed.py index dfa0486301765..1eb93b56673ba 100644 --- a/pytorch_lightning/plugins/checkpoint/deepspeed.py +++ b/pytorch_lightning/plugins/checkpoint/deepspeed.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Mapping, Union, Optional +from typing import Any, Dict, Mapping, Optional, Union import torch from torch import Tensor @@ -15,7 +15,6 @@ class DeepSpeedCheckpointPlugin(CheckpointPlugin): - def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
@@ -23,7 +22,11 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: checkpoint: The checkpoint state dictionary filepath: write-target file's path """ - if self.training_type_plugin.zero_stage_3 and self.training_type_plugin._multi_device and self.training_type_plugin.is_global_zero: + if ( + self.training_type_plugin.zero_stage_3 + and self.training_type_plugin._multi_device + and self.training_type_plugin.is_global_zero + ): warning_cache.warn( "When saving the DeepSpeed Stage 3 checkpoint, " "each worker will save a shard of the checkpoint within a directory. " diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index b1d4af686a211..6d3d03e22f585 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -11,7 +11,6 @@ class TorchCheckpointPlugin(CheckpointPlugin): - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: # dump states as a checkpoint dictionary object try: @@ -20,7 +19,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: except AttributeError as err: key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY checkpoint.pop(key, None) - rank_zero_warn(f'Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}') + rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") atomic_save(checkpoint, filepath) def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: From 91619801c6dd3abc4d522434c1b5b0f982613010 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 5 Aug 2021 13:50:48 +0100 Subject: [PATCH 04/37] Example API --- pytorch_lightning/plugins/training_type/ddp.py | 6 ++++-- pytorch_lightning/plugins/training_type/deepspeed.py | 11 ++++++----- pytorch_lightning/plugins/training_type/parallel.py | 4 +++- .../plugins/training_type/training_type_plugin.py | 5 ++--- .../trainer/connectors/accelerator_connector.py | 9 +++++++++ 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 003d567b35fc0..7720a00e89a36 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -31,6 +31,7 @@ from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import ( @@ -75,14 +76,15 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, - cluster_environment: ClusterEnvironment = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: - super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment, 
checkpoint_plugin=checkpoint_plugin) self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_deprecation( diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index b96a53bee60b7..57aeac21b4abb 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,6 +26,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.checkpoint.deepspeed import DeepSpeedCheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -114,7 +115,7 @@ def __init__( num_nodes: Optional[int] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[DeepSpeedCheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -276,12 +277,12 @@ def __init__( pin_memory = cpu_offload_use_pin_memory super().__init__( - parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment + parallel_devices=parallel_devices, + num_nodes=num_nodes, + cluster_environment=cluster_environment, + checkpoint_plugin=DeepSpeedCheckpointPlugin() if checkpoint_plugin is None else checkpoint_plugin ) - self.checkpoint_plugin = DeepSpeedPlugin() if checkpoint_plugin is None else checkpoint_plugin - self.checkpoint_plugin.training_type_plugin = self - self.config = self._load_config(config) if self.config is None: # User has not overridden config, set defaults diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 4186d697f21ac..b177f932741f0 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -21,6 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -34,8 +35,9 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None ): - super().__init__() + super().__init__(checkpoint_plugin) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index cff3c55063bde..21014348b15a2 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -26,7 +26,6 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.plugins.checkpoint.checkpoint import 
CheckpointPlugin -from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -37,11 +36,11 @@ class TrainingTypePlugin(Plugin, ABC): Base class for all training type plugins that change the behaviour of the training, validation and test-loop. """ - def __init__(self) -> None: + def __init__(self, checkpoint_plugin: Optional[CheckpointPlugin] = None) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None self._call_configure_sharded_model_hook = True - self._checkpoint_plugin = TorchCheckpointPlugin() + self._checkpoint_plugin = checkpoint_plugin @property def checkpoint_plugin(self) -> CheckpointPlugin: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 0f8d69706a147..f22469a5efb43 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -50,6 +50,7 @@ TrainingTypePlugin, TrainingTypePluginsRegistry, ) +from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin from pytorch_lightning.plugins.environments import ( ClusterEnvironment, KubeflowEnvironment, @@ -99,6 +100,7 @@ def __init__( precision, amp_type, amp_level, + checkpoint_plugin, plugins, ): # initialization @@ -134,6 +136,7 @@ def __init__( self._precision_plugin: Optional[PrecisionPlugin] = None self._training_type_plugin: Optional[TrainingTypePlugin] = None self._cluster_environment: Optional[ClusterEnvironment] = None + self.checkpoint_plugin = checkpoint_plugin plugins = plugins if plugins is not None else [] @@ -666,6 +669,12 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra training_type.cluster_environment = self.cluster_environment self._cluster_environment = proxy(self.cluster_environment) + if hasattr(training_type, "checkpoint_plugin"): + if getattr(training_type, "checkpoint_plugin") is None: + training_type.checkpoint_plugin = TorchCheckpointPlugin() + # todo (sean): this probably shouldn't happen here + training_type.checkpoint_plugin.training_type_plugin = training_type + if hasattr(training_type, "num_nodes"): # set num_nodes for training_type from trainer setting training_type.num_nodes = self.num_nodes From 7aa4e8c59d7ffed3db7b0fdd7d7cc4e94d6cc46a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Aug 2021 12:51:58 +0000 Subject: [PATCH 05/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp.py | 6 +++++- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- pytorch_lightning/plugins/training_type/parallel.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 7720a00e89a36..5aee14f8f1bab 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -84,7 +84,11 @@ def __init__( ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: - super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_plugin=checkpoint_plugin) + super().__init__( + 
parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_plugin=checkpoint_plugin, + ) self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_deprecation( diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 57aeac21b4abb..8b3108ffa3e52 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -280,7 +280,7 @@ def __init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment, - checkpoint_plugin=DeepSpeedCheckpointPlugin() if checkpoint_plugin is None else checkpoint_plugin + checkpoint_plugin=DeepSpeedCheckpointPlugin() if checkpoint_plugin is None else checkpoint_plugin, ) self.config = self._load_config(config) diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index b177f932741f0..9177f6dd95aff 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -35,7 +35,7 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None + checkpoint_plugin: Optional[CheckpointPlugin] = None, ): super().__init__(checkpoint_plugin) self.parallel_devices = parallel_devices From dffe088b96bedd1c7380b3ffbab67dbf04e4e896 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 5 Aug 2021 14:02:00 +0100 Subject: [PATCH 06/37] Update all constructors --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 8 +++++++- pytorch_lightning/plugins/training_type/dp.py | 11 +++++++++-- .../plugins/training_type/fully_sharded.py | 11 ++++++++--- pytorch_lightning/plugins/training_type/horovod.py | 11 +++++++++-- pytorch_lightning/plugins/training_type/ipu.py | 8 +++++++- .../plugins/training_type/single_device.py | 9 +++++++-- .../plugins/training_type/single_tpu.py | 12 +++++++++--- pytorch_lightning/plugins/training_type/tpu_spawn.py | 11 +++++++++-- 8 files changed, 65 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index c39d35e8d1b6a..5d8224aee940d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -24,6 +24,7 @@ from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn @@ -63,13 +64,18 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: ClusterEnvironment = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, **kwargs: Any, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + super().__init__( + 
parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_plugin=checkpoint_plugin, + ) if num_nodes is not None: rank_zero_deprecation( "Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. " diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index beedac2942ac6..2a2324538e508 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -17,6 +17,7 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -29,8 +30,14 @@ class DataParallelPlugin(ParallelPlugin): device and each gets a split of the data. """ - def __init__(self, parallel_devices: Optional[List[torch.device]]): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + def __init__( + self, + parallel_devices: Optional[List[torch.device]], + checkpoint_plugin: Optional[CheckpointPlugin] = None, + ): + super().__init__( + parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin + ) @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index bbf6143ce6b66..f716fae855cfa 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Dict, Generator, List, Optional import torch -from torch import Tensor +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE @@ -41,6 +41,7 @@ def __init__( state_dict_to_cpu: bool = True, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: ClusterEnvironment = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None, ): """ Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -89,7 +90,11 @@ def __init__( (Defautl: True). 
""" - super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_plugin=checkpoint_plugin, + ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu self.flatten_parameters = flatten_parameters diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 34fe429d89362..3003ed19ce3f1 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -20,6 +20,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available @@ -33,8 +34,14 @@ class HorovodPlugin(ParallelPlugin): """Plugin for Horovod distributed training integration.""" - def __init__(self, parallel_devices: Optional[List[torch.device]] = None): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + def __init__( + self, + parallel_devices: Optional[List[torch.device]] = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None, + ): + super().__init__( + parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin + ) rank_zero_only.rank = self.global_rank @property diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 82a0d3db68e64..10a3ab64afee0 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -22,6 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage @@ -67,6 +68,7 @@ def __init__( autoreport_dir: Optional[str] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -83,7 +85,11 @@ def __init__( inference_opts: Optional ``poptorch.Options`` to override the default created options for validation/testing and predicting. """ - super().__init__(parallel_devices, cluster_environment) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_plugin=checkpoint_plugin, + ) if not _POPTORCH_AVAILABLE or not poptorch.ipuHardwareIsAvailable(): raise MisconfigurationException( "The IPU Accelerator requires IPU devices to run. 
" diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 5399cffe19f68..0e8cf406f3fed 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -15,6 +15,7 @@ import torch +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -22,8 +23,12 @@ class SingleDevicePlugin(TrainingTypePlugin): """Plugin that handles communication on a single device.""" - def __init__(self, device: torch.device): - super().__init__() + def __init__( + self, + device: torch.device, + checkpoint_plugin: Optional[CheckpointPlugin] = None, + ): + super().__init__(checkpoint_plugin) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index d83bd7ed8ba17..f26748a564efb 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Dict +from typing import Any, Dict, Optional from pytorch_lightning.core.decorators import parameter_validation +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -29,10 +30,15 @@ class SingleTPUPlugin(SingleDevicePlugin): """Plugin for training on a single TPU device.""" - def __init__(self, device: int, debug: bool = False): + def __init__( + self, + device: int, + debug: bool = False, + checkpoint_plugin: Optional[CheckpointPlugin] = None, + ): device = xm.xla_device(device) - super().__init__(device) + super().__init__(device=device, checkpoint_plugin=checkpoint_plugin) self.debug = debug self.tpu_local_core_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index faec805773cb7..7d203966717c2 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.decorators import parameter_validation from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import TrainerFn @@ -52,8 +53,14 @@ class TPUSpawnPlugin(DDPSpawnPlugin): """Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method.""" - def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None: - super().__init__(parallel_devices) + def __init__( + self, + parallel_devices: Optional[List[int]] = None, + checkpoint_plugin: Optional[CheckpointPlugin] = None, + debug: bool = False, + **_: Any + ) -> None: + 
super().__init__(parallel_devices=parallel_devices, checkpoint_plugin=checkpoint_plugin) self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 From cacf0e5b49a6239c291e199221e0ef5e51942ccb Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 6 Aug 2021 10:06:24 +0100 Subject: [PATCH 07/37] Move towards having the checkpoint plugin not require the plugin, and focus on IO --- .../plugins/checkpoint/checkpoint.py | 40 +----- .../plugins/checkpoint/deepspeed.py | 123 ------------------ pytorch_lightning/plugins/checkpoint/torch.py | 17 +-- .../plugins/training_type/deepspeed.py | 116 ++++++++++++++--- .../training_type/training_type_plugin.py | 47 +++---- .../connectors/accelerator_connector.py | 7 +- 6 files changed, 130 insertions(+), 220 deletions(-) delete mode 100644 pytorch_lightning/plugins/checkpoint/deepspeed.py diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index 33d2a6cb9bf2b..fac6bbe547f48 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -1,32 +1,9 @@ from abc import ABC from pathlib import Path -from typing import Any, Dict, Mapping, Union - -from torch.nn import Module - -from pytorch_lightning import LightningModule +from typing import Any, Dict, Union class CheckpointPlugin(ABC): - def __init__(self): - self._training_type_plugin = None - - @property - def training_type_plugin(self) -> "TrainingTypePlugin": - return self._training_type_plugin - - @training_type_plugin.setter - def training_type_plugin(self, plugin) -> None: - self._training_type_plugin = plugin - - @property - def lightning_module(self) -> LightningModule: - return self.training_type_plugin.lightning_module - - @property - def model(self) -> Module: - return self.training_type_plugin.model - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -43,18 +20,3 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A Returns: The loaded checkpoint. """ - pass - - def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - """ - Given the loaded checkpoint file, loads the state dict into the model. - Args: - checkpoint: The loaded checkpoint file. - """ - - def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - """ - Given the loaded checkpoint file, loads the optimizer state dicts into optimizers. - Args: - checkpoint: The loaded checkpoint file. 
- """ diff --git a/pytorch_lightning/plugins/checkpoint/deepspeed.py b/pytorch_lightning/plugins/checkpoint/deepspeed.py deleted file mode 100644 index 1eb93b56673ba..0000000000000 --- a/pytorch_lightning/plugins/checkpoint/deepspeed.py +++ /dev/null @@ -1,123 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Mapping, Optional, Union - -import torch -from torch import Tensor - -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE -from pytorch_lightning.utilities.warnings import WarningCache - -warning_cache = WarningCache() -if _DEEPSPEED_AVAILABLE: - import deepspeed - - -class DeepSpeedCheckpointPlugin(CheckpointPlugin): - def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: - """Save model/training states as a checkpoint file through state-dump and file-write. - - Args: - checkpoint: The checkpoint state dictionary - filepath: write-target file's path - """ - if ( - self.training_type_plugin.zero_stage_3 - and self.training_type_plugin._multi_device - and self.training_type_plugin.is_global_zero - ): - warning_cache.warn( - "When saving the DeepSpeed Stage 3 checkpoint, " - "each worker will save a shard of the checkpoint within a directory. " - "If a single file is required after training, " - "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#" - "deepspeed-zero-stage-3-single-file for instructions." - ) - # Use deepspeed's internal checkpointing function to handle partitioned weights across processes - # dump states as a checkpoint dictionary object - _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] - checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} - self.model.save_checkpoint(filepath, client_state=checkpoint) - - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: - if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3: - # Broadcast to ensure we load from the rank 0 checkpoint - # This doesn't have to be the case when using deepspeed sharded checkpointing - checkpoint_path = self.training_type_plugin.broadcast(checkpoint_path) - return super().load_checkpoint_file(checkpoint_path) - - # Rely on deepspeed to load the checkpoint and necessary information - from pytorch_lightning.trainer.states import TrainerFn - - is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING - _, client_state = self.model.load_checkpoint( - checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting - ) - if client_state is None: - raise MisconfigurationException( - "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint " - "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`." 
- ) - return client_state - - def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` - if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3: - self.training_type_plugin.model_to_device() - self._restore_zero_state(checkpoint) - - def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` - pass - - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: - """Returns model state.""" - model = self.lightning_module - return model.state_dict() - - def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: - """ - Overrides the normal load_state_dict behaviour in PyTorch to ensure - we gather parameters that may be sharded across processes before loading - the state dictionary when using ZeRO stage 3. - This is then automatically synced across processes. - - Args: - ckpt: The ckpt file. - """ - - def load(module: torch.nn.Module, prefix=""): - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - state_dict = ckpt["state_dict"] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - # because zero3 puts placeholders in model params, this context - # manager gathers (unpartitions) the params of the current layer, then loads from - # the state dict and then re-partitions them again - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): - if self.training_type_plugin.is_global_zero: - module._load_from_state_dict( - state_dict=state_dict, - prefix=prefix, - local_metadata=local_metadata, - strict=True, - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - error_msgs=error_msgs, - ) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(self.lightning_module, prefix="") diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 6d3d03e22f585..09b23d0ff6ac4 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -1,7 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Mapping, Union - -from torch import Tensor +from typing import Any, Dict, Union import pytorch_lightning as pl from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin @@ -24,16 +22,3 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage)) - - def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - self.lightning_module.load_state_dict(checkpoint["state_dict"]) - - def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - optimizer_states = checkpoint["optimizer_states"] - for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): - optimizer.load_state_dict(opt_state) - - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: - """Returns model state.""" - model = 
self.lightning_module - return model.state_dict() diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 8b3108ffa3e52..c30699d6108ab 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -27,7 +27,6 @@ from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin -from pytorch_lightning.plugins.checkpoint.deepspeed import DeepSpeedCheckpointPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config @@ -276,11 +275,13 @@ def __init__( offload_parameters = cpu_offload_params pin_memory = cpu_offload_use_pin_memory + if checkpoint_plugin is not None: + raise MisconfigurationException("DeepSpeed currently does not support passing a custom checkpoint plugin.") + super().__init__( parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment, - checkpoint_plugin=DeepSpeedCheckpointPlugin() if checkpoint_plugin is None else checkpoint_plugin, ) self.config = self._load_config(config) @@ -675,9 +676,50 @@ def deepspeed_engine(self): return self.model @property - def multi_device(self) -> bool: + def _multi_device(self) -> bool: return self.num_processes > 1 or self.num_nodes > 1 + def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: The checkpoint state dictionary + filepath: write-target file's path + """ + if self.zero_stage_3 and self._multi_device and self.is_global_zero: + # todo (sean): Add link to docs once docs are merged. + warning_cache.warn( + "When saving the DeepSpeed Stage 3 checkpoint, " + "each worker will save a shard of the checkpoint within a directory. " + "If a single file is required after training, see for instructions." + ) + # Use deepspeed's internal checkpointing function to handle partitioned weights across processes + # dump states as a checkpoint dictionary object + _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] + checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} + self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) + + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: + if self.load_full_weights and self.zero_stage_3: + # Broadcast to ensure we load from the rank 0 checkpoint + # This doesn't have to be the case when using deepspeed sharded checkpointing + checkpoint_path = self.broadcast(checkpoint_path) + return super().load_checkpoint_file(checkpoint_path) + + # Rely on deepspeed to load the checkpoint and necessary information + from pytorch_lightning.trainer.states import TrainerFn + + is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING + _, client_state = self.deepspeed_engine.load_checkpoint( + checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting + ) + if client_state is None: + raise MisconfigurationException( + "DeepSpeed was unable to load the checkpoint. 
Ensure you passed in a DeepSpeed compatible checkpoint " + "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`." + ) + return client_state + @property def lightning_restore_optimizer_and_schedulers(self) -> bool: # managed by DeepSpeed @@ -689,6 +731,62 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool: ) return False + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()` + if self.load_full_weights and self.zero_stage_3: + self.model_to_device() + self._restore_zero_state(checkpoint) + + def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: + """ + Overrides the normal load_state_dict behaviour in PyTorch to ensure + we gather parameters that may be sharded across processes before loading + the state dictionary when using ZeRO stage 3. + This is then automatically synced across processes. + + Args: + ckpt: The ckpt file. + """ + + def load(module: torch.nn.Module, prefix=""): + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + state_dict = ckpt["state_dict"] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if self.is_global_zero: + module._load_from_state_dict( + state_dict=state_dict, + prefix=prefix, + local_metadata=local_metadata, + strict=True, + missing_keys=missing_keys, + unexpected_keys=unexpected_keys, + error_msgs=error_msgs, + ) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self.lightning_module, prefix="") + + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()` + pass + def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: if self._original_accumulate_grad_batches is None: return super().update_global_step(total_batch_idx, current_global_step) @@ -727,15 +825,3 @@ def register_plugins(cls, plugin_registry: Dict) -> None: offload_params_device="nvme", offload_optimizer_device="nvme", ) - - def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - self.checkpoint_plugin.load_model_state_dict(checkpoint) - - def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: - return self.checkpoint_plugin.save_checkpoint(checkpoint) - - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]: - return self.checkpoint_plugin.load_checkpoint_file(checkpoint_path) - - def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - return self.checkpoint_plugin.load_optimizer_state_dict(checkpoint) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 21014348b15a2..8dd5f52db92e1 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ 
b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -152,6 +152,17 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: """ return self._results + def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + return self.checkpoint_plugin.load_checkpoint_file(checkpoint_path) + + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + self.lightning_module.load_state_dict(checkpoint["state_dict"]) + + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + optimizer_states = checkpoint["optimizer_states"] + for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states): + optimizer.load_state_dict(opt_state) + def start_training(self, trainer: "pl.Trainer") -> None: # double dispatch to initiate the training loop self._results = trainer.run_stage() @@ -264,6 +275,20 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> """ return current_global_step + 1 + def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + """Returns model state.""" + model = self.lightning_module + return model.state_dict() + + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ + return self.checkpoint_plugin.save_checkpoint(checkpoint, filepath) + @contextlib.contextmanager def model_sharded_context(self) -> Generator: """ @@ -342,25 +367,3 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) Called in the training loop before anything happens for that batch. """ pass - - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - return self.checkpoint_plugin.load_checkpoint_file(checkpoint_path) - - def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - self.checkpoint_plugin.load_model_state_dict(checkpoint) - - def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - self.checkpoint_plugin.load_optimizer_state_dict(checkpoint) - - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: - """Returns model state.""" - return self.checkpoint_plugin.lightning_module_state_dict() - - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: - """Save model/training states as a checkpoint file through state-dump and file-write. 
- - Args: - checkpoint: dict containing model and trainer state - filepath: write-target file's path - """ - self.checkpoint_plugin.save_checkpoint(checkpoint, filepath) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index f22469a5efb43..207b8b196013b 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -669,11 +669,8 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra training_type.cluster_environment = self.cluster_environment self._cluster_environment = proxy(self.cluster_environment) - if hasattr(training_type, "checkpoint_plugin"): - if getattr(training_type, "checkpoint_plugin") is None: - training_type.checkpoint_plugin = TorchCheckpointPlugin() - # todo (sean): this probably shouldn't happen here - training_type.checkpoint_plugin.training_type_plugin = training_type + if hasattr(training_type, "checkpoint_plugin") and getattr(training_type, "checkpoint_plugin") is None: + training_type.checkpoint_plugin = TorchCheckpointPlugin() if hasattr(training_type, "num_nodes"): # set num_nodes for training_type from trainer setting From 028ac384402f71af2109ec4c7d13fb74333366ef Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 6 Aug 2021 10:35:18 +0100 Subject: [PATCH 08/37] Remove import --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 207b8b196013b..001890ecc5474 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -100,7 +100,6 @@ def __init__( precision, amp_type, amp_level, - checkpoint_plugin, plugins, ): # initialization @@ -136,7 +135,6 @@ def __init__( self._precision_plugin: Optional[PrecisionPlugin] = None self._training_type_plugin: Optional[TrainingTypePlugin] = None self._cluster_environment: Optional[ClusterEnvironment] = None - self.checkpoint_plugin = checkpoint_plugin plugins = plugins if plugins is not None else [] From 99c7a46cf5c83bfd28cdc5a0f29d8af62da7754c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 6 Aug 2021 11:35:19 +0100 Subject: [PATCH 09/37] Fix tests --- tests/accelerators/test_cpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 1a584ed444758..83a5b00551eac 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -9,6 +9,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin +from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -201,7 +202,7 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu")) + plugin = TestPlugin(torch.device("cpu"), checkpoint_plugin=TorchCheckpointPlugin()) accelerator = 
CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch From 3adc486bbdae487bf43f609f91c46fe94d317332 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 6 Aug 2021 12:39:32 +0100 Subject: [PATCH 10/37] Change name --- pytorch_lightning/plugins/__init__.py | 4 ++++ pytorch_lightning/plugins/checkpoint/__init__.py | 2 ++ pytorch_lightning/plugins/checkpoint/checkpoint.py | 2 +- pytorch_lightning/plugins/checkpoint/torch.py | 4 ++-- pytorch_lightning/plugins/training_type/ddp.py | 4 ++-- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/deepspeed.py | 4 ++-- pytorch_lightning/plugins/training_type/dp.py | 4 ++-- pytorch_lightning/plugins/training_type/fully_sharded.py | 4 ++-- pytorch_lightning/plugins/training_type/horovod.py | 4 ++-- pytorch_lightning/plugins/training_type/ipu.py | 4 ++-- pytorch_lightning/plugins/training_type/parallel.py | 4 ++-- pytorch_lightning/plugins/training_type/single_device.py | 4 ++-- pytorch_lightning/plugins/training_type/single_tpu.py | 4 ++-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- .../plugins/training_type/training_type_plugin.py | 6 +++--- .../trainer/connectors/accelerator_connector.py | 4 ++-- tests/accelerators/test_cpu.py | 4 ++-- 18 files changed, 38 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index d6434e84adae7..f8e8dedaccae5 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,4 +1,6 @@ from pytorch_lightning.plugins.base_plugin import Plugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401 call_training_type_register_plugins, TrainingTypePluginsRegistry, @@ -29,6 +31,8 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin __all__ = [ + "CheckpointIOPlugin", + "TorchCheckpointIOPlugin", "ApexMixedPrecisionPlugin", "DataParallelPlugin", "DDP2Plugin", diff --git a/pytorch_lightning/plugins/checkpoint/__init__.py b/pytorch_lightning/plugins/checkpoint/__init__.py index e69de29bb2d1d..cf936291c7a26 100644 --- a/pytorch_lightning/plugins/checkpoint/__init__.py +++ b/pytorch_lightning/plugins/checkpoint/__init__.py @@ -0,0 +1,2 @@ +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin # noqa: F401 +from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index fac6bbe547f48..ce1693ff1a138 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Union -class CheckpointPlugin(ABC): +class CheckpointIOPlugin(ABC): def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
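A minimal usage sketch, not part of the patch itself: with the renamed classes re-exported from `pytorch_lightning.plugins`, a checkpoint IO plugin is handed to a training type plugin, which is in turn passed to the Trainer. The snippet assumes the `SingleDevicePlugin(device, checkpoint_plugin=...)` and `Trainer(plugins=...)` signatures shown elsewhere in this series; `max_epochs` is an ordinary Trainer argument used only to keep the example small.

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin, TorchCheckpointIOPlugin

# Wire the default torch-based checkpoint IO into a single-device training type
# plugin; checkpoint saving/loading is then delegated to this plugin.
checkpoint_io = TorchCheckpointIOPlugin()
trainer = Trainer(
    plugins=SingleDevicePlugin(torch.device("cpu"), checkpoint_plugin=checkpoint_io),
    max_epochs=1,
)
# trainer.fit(...) / trainer.save_checkpoint(...) would now route checkpoint
# file I/O through `checkpoint_io`.
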
diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 09b23d0ff6ac4..9bf850f7cd41a 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -2,13 +2,13 @@ from typing import Any, Dict, Union import pytorch_lightning as pl -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load -class TorchCheckpointPlugin(CheckpointPlugin): +class TorchCheckpointIOPlugin(CheckpointIOPlugin): def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: # dump states as a checkpoint dictionary object try: diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 5aee14f8f1bab..d4daf3dc6f9f1 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -31,7 +31,7 @@ from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import ( @@ -77,7 +77,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 5d8224aee940d..ae1ea6a25716e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -24,7 +24,7 @@ from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn @@ -64,7 +64,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: ClusterEnvironment = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py 
b/pytorch_lightning/plugins/training_type/deepspeed.py index c30699d6108ab..5e5789baf2372 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,7 +26,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config @@ -114,7 +114,7 @@ def __init__( num_nodes: Optional[int] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 2a2324538e508..2fe80a3ee9c3f 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -17,7 +17,7 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -33,7 +33,7 @@ class DataParallelPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]], - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): super().__init__( parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index f716fae855cfa..10ac61ff5decf 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -16,7 +16,7 @@ import torch -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE @@ -41,7 +41,7 @@ def __init__( state_dict_to_cpu: bool = True, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: ClusterEnvironment = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): """ Plugin for Fully Sharded Data Parallel provided by FairScale. 
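Each of the training type plugins above now accepts an optional `checkpoint_plugin` of the renamed `CheckpointIOPlugin` type. A user-defined implementation could extend the torch-based default rather than the abstract base; the short sketch below is again not part of the patch. Method names and the `filepath`/`checkpoint_path` parameters follow the `TorchCheckpointIOPlugin` signatures as they stand at this point in the series (later commits in the series rename them), and the logging added here is purely hypothetical.

from pathlib import Path
from typing import Any, Dict, Union

from pytorch_lightning.plugins import TorchCheckpointIOPlugin


class VerboseTorchCheckpointIOPlugin(TorchCheckpointIOPlugin):
    # Hypothetical subclass for illustration: reuse the torch-based save/load
    # behaviour, only reporting where checkpoints are written and read.
    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        print(f"writing checkpoint to {filepath}")
        super().save_checkpoint(checkpoint, filepath)

    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        print(f"reading checkpoint from {checkpoint_path}")
        return super().load_checkpoint_file(checkpoint_path)

An instance of such a subclass would be passed via `checkpoint_plugin=` exactly as in the previous sketch.
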
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 3003ed19ce3f1..3657487fb237d 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available @@ -37,7 +37,7 @@ class HorovodPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]] = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): super().__init__( parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 10a3ab64afee0..35283e0680b62 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -22,7 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage @@ -68,7 +68,7 @@ def __init__( autoreport_dir: Optional[str] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 9177f6dd95aff..df2f1dffde0cd 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -21,7 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -35,7 +35,7 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): super().__init__(checkpoint_plugin) self.parallel_devices = parallel_devices diff --git 
a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 0e8cf406f3fed..da70afb0efc6f 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -15,7 +15,7 @@ import torch -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -26,7 +26,7 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__( self, device: torch.device, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): super().__init__(checkpoint_plugin) self.device: torch.device = device diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index f26748a564efb..f4654e91bc11c 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Optional from pytorch_lightning.core.decorators import parameter_validation -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -34,7 +34,7 @@ def __init__( self, device: int, debug: bool = False, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): device = xm.xla_device(device) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7d203966717c2..de92b2bf561f4 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -25,7 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.decorators import parameter_validation from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import TrainerFn @@ -56,7 +56,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def __init__( self, parallel_devices: Optional[List[int]] = None, - checkpoint_plugin: Optional[CheckpointPlugin] = None, + checkpoint_plugin: Optional[CheckpointIOPlugin] = None, debug: bool = False, **_: Any ) -> None: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 8dd5f52db92e1..771ca34274428 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -25,7 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.base_plugin import Plugin -from 
pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -36,14 +36,14 @@ class TrainingTypePlugin(Plugin, ABC): Base class for all training type plugins that change the behaviour of the training, validation and test-loop. """ - def __init__(self, checkpoint_plugin: Optional[CheckpointPlugin] = None) -> None: + def __init__(self, checkpoint_plugin: Optional[CheckpointIOPlugin] = None) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None self._call_configure_sharded_model_hook = True self._checkpoint_plugin = checkpoint_plugin @property - def checkpoint_plugin(self) -> CheckpointPlugin: + def checkpoint_plugin(self) -> CheckpointIOPlugin: return self._checkpoint_plugin @checkpoint_plugin.setter diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 001890ecc5474..aa548e8978db8 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -50,7 +50,7 @@ TrainingTypePlugin, TrainingTypePluginsRegistry, ) -from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin +from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin from pytorch_lightning.plugins.environments import ( ClusterEnvironment, KubeflowEnvironment, @@ -668,7 +668,7 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra self._cluster_environment = proxy(self.cluster_environment) if hasattr(training_type, "checkpoint_plugin") and getattr(training_type, "checkpoint_plugin") is None: - training_type.checkpoint_plugin = TorchCheckpointPlugin() + training_type.checkpoint_plugin = TorchCheckpointIOPlugin() if hasattr(training_type, "num_nodes"): # set num_nodes for training_type from trainer setting diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 83a5b00551eac..6d92b882bc368 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin -from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin +from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -202,7 +202,7 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu"), checkpoint_plugin=TorchCheckpointPlugin()) + plugin = TestPlugin(torch.device("cpu"), checkpoint_plugin=TorchCheckpointIOPlugin()) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch From b7d5b55e206377719e1c12844fd98830947e0ddc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 9 Aug 2021 10:11:23 +0100 Subject: [PATCH 11/37] Cleanups --- 
pytorch_lightning/plugins/checkpoint/checkpoint.py | 6 ++++-- pytorch_lightning/plugins/training_type/deepspeed.py | 12 ++++++++---- .../plugins/training_type/training_type_plugin.py | 6 ++++-- .../trainer/connectors/accelerator_connector.py | 4 ---- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index ce1693ff1a138..2621d41501554 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -1,10 +1,11 @@ -from abc import ABC +from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, Union class CheckpointIOPlugin(ABC): - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + @abstractmethod + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -12,6 +13,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: filepath: write-target file's path """ + @abstractmethod def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 5e5789baf2372..1af959fdd3287 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -114,7 +114,6 @@ def __init__( num_nodes: Optional[int] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -275,9 +274,6 @@ def __init__( offload_parameters = cpu_offload_params pin_memory = cpu_offload_use_pin_memory - if checkpoint_plugin is not None: - raise MisconfigurationException("DeepSpeed currently does not support passing a custom checkpoint plugin.") - super().__init__( parallel_devices=parallel_devices, num_nodes=num_nodes, @@ -825,3 +821,11 @@ def register_plugins(cls, plugin_registry: Dict) -> None: offload_params_device="nvme", offload_optimizer_device="nvme", ) + + @property + def checkpoint_plugin(self) -> CheckpointIOPlugin: + return self._checkpoint_plugin + + @checkpoint_plugin.setter + def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: + raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 771ca34274428..1790118b83f8d 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -24,6 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins import TorchCheckpointIOPlugin from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT @@ -39,15 +40,16 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self, 
checkpoint_plugin: Optional[CheckpointIOPlugin] = None) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None + checkpoint_plugin = checkpoint_plugin if checkpoint_plugin is not None else TorchCheckpointIOPlugin() + self._checkpoint_plugin: CheckpointIOPlugin = checkpoint_plugin self._call_configure_sharded_model_hook = True - self._checkpoint_plugin = checkpoint_plugin @property def checkpoint_plugin(self) -> CheckpointIOPlugin: return self._checkpoint_plugin @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin) -> None: + def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: self._checkpoint_plugin = plugin def connect(self, model: Module) -> None: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index aa548e8978db8..0f8d69706a147 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -50,7 +50,6 @@ TrainingTypePlugin, TrainingTypePluginsRegistry, ) -from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin from pytorch_lightning.plugins.environments import ( ClusterEnvironment, KubeflowEnvironment, @@ -667,9 +666,6 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra training_type.cluster_environment = self.cluster_environment self._cluster_environment = proxy(self.cluster_environment) - if hasattr(training_type, "checkpoint_plugin") and getattr(training_type, "checkpoint_plugin") is None: - training_type.checkpoint_plugin = TorchCheckpointIOPlugin() - if hasattr(training_type, "num_nodes"): # set num_nodes for training_type from trainer setting training_type.num_nodes = self.num_nodes From 0a0a0689566d93573387935ac39f303f1f552cf3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 9 Aug 2021 12:18:16 +0100 Subject: [PATCH 12/37] Fixes/Cleanups --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 4 ++-- pytorch_lightning/plugins/checkpoint/torch.py | 7 +++---- .../plugins/training_type/training_type_plugin.py | 5 ++++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index 2621d41501554..63f4ac2b0adbb 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -5,12 +5,12 @@ class CheckpointIOPlugin(ABC): @abstractmethod - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
Args: checkpoint: dict containing model and trainer state - filepath: write-target file's path + path: write-target path """ @abstractmethod diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 9bf850f7cd41a..6d61696594ade 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -9,16 +9,15 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: - # dump states as a checkpoint dictionary object + def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None: try: # write the checkpoint dictionary on the file - atomic_save(checkpoint, filepath) + atomic_save(checkpoint, path) except AttributeError as err: key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY checkpoint.pop(key, None) rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") - atomic_save(checkpoint, filepath) + atomic_save(checkpoint, path) def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage)) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 1790118b83f8d..ab7c7a161638e 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -289,7 +289,10 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: checkpoint: dict containing model and trainer state filepath: write-target file's path """ - return self.checkpoint_plugin.save_checkpoint(checkpoint, filepath) + # dump states as a checkpoint dictionary object + checkpoint = self.on_save(checkpoint) + if self.is_global_zero: + return self.checkpoint_plugin.save_checkpoint(checkpoint, filepath) @contextlib.contextmanager def model_sharded_context(self) -> Generator: From 97fb2a2e813caa4b961cc64bf03ef89fd61415e8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 9 Aug 2021 12:30:06 +0100 Subject: [PATCH 13/37] Use property --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index ab7c7a161638e..3c478346a3bb7 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -291,7 +291,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """ # dump states as a checkpoint dictionary object checkpoint = self.on_save(checkpoint) - if self.is_global_zero: + if self.should_rank_save_checkpoint: return self.checkpoint_plugin.save_checkpoint(checkpoint, filepath) @contextlib.contextmanager From 402156e983c84b0ec2444f9800f6759c8fa8a1ce Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 9 Aug 2021 13:42:12 +0100 Subject: [PATCH 14/37] Fixes to signature --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 4 ++-- pytorch_lightning/plugins/checkpoint/torch.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index 63f4ac2b0adbb..1543215c1ca17 100644 --- 
a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -14,11 +14,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> """ @abstractmethod - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: - checkpoint_path: Path to checkpoint + path: Path to checkpoint Returns: The loaded checkpoint. """ diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 6d61696594ade..b05903a27a776 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -9,7 +9,7 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): - def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: try: # write the checkpoint dictionary on the file atomic_save(checkpoint, path) @@ -19,5 +19,5 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None: rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") atomic_save(checkpoint, path) - def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage)) + def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: + return pl_load(path, map_location=(lambda storage, loc: storage)) From d7f567a98d2b2d13e2b0c3fe8f815d28d1f9ef55 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 9 Aug 2021 15:56:12 +0100 Subject: [PATCH 15/37] Add warning for TPU plugins that they do not support custom checkpoint plugins, add rudimentary test --- .../plugins/training_type/single_tpu.py | 14 ++++- .../plugins/training_type/tpu_spawn.py | 18 +++--- tests/plugins/test_checkpoint_plugin.py | 58 +++++++++++++++++++ 3 files changed, 79 insertions(+), 11 deletions(-) create mode 100644 tests/plugins/test_checkpoint_plugin.py diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index f4654e91bc11c..2e9987bb542e9 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Dict, Optional +from typing import Any, Dict from pytorch_lightning.core.decorators import parameter_validation from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm @@ -34,11 +35,10 @@ def __init__( self, device: int, debug: bool = False, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): device = xm.xla_device(device) - super().__init__(device=device, checkpoint_plugin=checkpoint_plugin) + super().__init__(device=device) self.debug = debug self.tpu_local_core_rank = 0 @@ -80,3 +80,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: def teardown(self) -> None: # TPU teardown os.environ.pop("PT_XLA_DEBUG", None) + + @property + def checkpoint_plugin(self) -> CheckpointIOPlugin: + return self._checkpoint_plugin + + @checkpoint_plugin.setter + def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: + raise MisconfigurationException("TPU Plugin currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index de92b2bf561f4..c72d8069cc72a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -53,14 +53,8 @@ class TPUSpawnPlugin(DDPSpawnPlugin): """Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method.""" - def __init__( - self, - parallel_devices: Optional[List[int]] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, - debug: bool = False, - **_: Any - ) -> None: - super().__init__(parallel_devices=parallel_devices, checkpoint_plugin=checkpoint_plugin) + def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None: + super().__init__(parallel_devices=parallel_devices) self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 @@ -352,3 +346,11 @@ def should_rank_save_checkpoint(self) -> bool: @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register("tpu_spawn_debug", cls, description="TPUSpawn Plugin with `debug` as True", debug=True) + + @property + def checkpoint_plugin(self) -> CheckpointIOPlugin: + return self._checkpoint_plugin + + @checkpoint_plugin.setter + def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: + raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") diff --git a/tests/plugins/test_checkpoint_plugin.py b/tests/plugins/test_checkpoint_plugin.py new file mode 100644 index 0000000000000..635ff96dd37b4 --- /dev/null +++ b/tests/plugins/test_checkpoint_plugin.py @@ -0,0 +1,58 @@ +from pathlib import Path +from typing import Any, Dict, Union + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins import CheckpointIOPlugin, SingleDevicePlugin, TorchCheckpointIOPlugin +from tests.helpers.boring_model import BoringModel + + +class CustomCheckpointPlugin(CheckpointIOPlugin): + save_checkpoint_called: bool = 
False + load_checkpoint_file_called: bool = False + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: + self.save_checkpoint_called = True + torch.save(checkpoint, path) + + def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: + self.load_checkpoint_file_called = True + return torch.load(path) + + +class CustomTorchCheckpointIOPlugin(TorchCheckpointIOPlugin): + save_checkpoint_called: bool = False + load_checkpoint_file_called: bool = False + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: + self.save_checkpoint_called = True + super().save_checkpoint(checkpoint, path) + + def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: + self.load_checkpoint_file_called = True + return super().load_checkpoint_file(path) + + +@pytest.mark.parametrize("checkpoint_plugin", [CustomTorchCheckpointIOPlugin(), CustomCheckpointPlugin()]) +def test_checkpoint_plugin_called(tmpdir, checkpoint_plugin): + """ + Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading. + """ + + ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) + + model = BoringModel() + device = torch.device("cpu") + trainer = Trainer( + default_root_dir=tmpdir, + plugins=SingleDevicePlugin(device, checkpoint_plugin=checkpoint_plugin), + callbacks=ck, + max_epochs=1, + ) + trainer.fit(model) + assert checkpoint_plugin.save_checkpoint_called + trainer.test(model, ckpt_path=ck.last_model_path) + assert checkpoint_plugin.load_checkpoint_file_called From fcc24b44a3089853cfa945d5e6e26bce7724d5bc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 10 Aug 2021 10:20:44 +0100 Subject: [PATCH 16/37] Cleanup API, introduce storage options --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 10 +++++++--- pytorch_lightning/plugins/checkpoint/torch.py | 12 ++++++++---- .../plugins/training_type/training_type_plugin.py | 2 +- tests/plugins/test_checkpoint_plugin.py | 6 +++--- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index 1543215c1ca17..851173f46c917 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -1,24 +1,28 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Mapping, Optional, Union class CheckpointIOPlugin(ABC): @abstractmethod - def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Mapping] = None + ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: checkpoint: dict containing model and trainer state path: write-target path + storage_options: Optional parameters when saving the model/training states. """ @abstractmethod - def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Mapping] = None) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: path: Path to checkpoint + storage_options: Optional parameters when loading the model/training states. Returns: The loaded checkpoint. 
""" diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index b05903a27a776..31ed28625786b 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Callable, Dict, Mapping, Optional, Union import pytorch_lightning as pl from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin @@ -9,7 +9,9 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): - def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Mapping] = None + ) -> None: try: # write the checkpoint dictionary on the file atomic_save(checkpoint, path) @@ -19,5 +21,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") atomic_save(checkpoint, path) - def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: - return pl_load(path, map_location=(lambda storage, loc: storage)) + def load_checkpoint( + self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage + ) -> Dict[str, Any]: + return pl_load(path, map_location=map_location) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 3c478346a3bb7..e1c4fe4b63281 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -155,7 +155,7 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: return self._results def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - return self.checkpoint_plugin.load_checkpoint_file(checkpoint_path) + return self.checkpoint_plugin.load_checkpoint(checkpoint_path) def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: self.lightning_module.load_state_dict(checkpoint["state_dict"]) diff --git a/tests/plugins/test_checkpoint_plugin.py b/tests/plugins/test_checkpoint_plugin.py index 635ff96dd37b4..f0acb1b42f06c 100644 --- a/tests/plugins/test_checkpoint_plugin.py +++ b/tests/plugins/test_checkpoint_plugin.py @@ -18,7 +18,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> self.save_checkpoint_called = True torch.save(checkpoint, path) - def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint(self, path: Union[str, Path]) -> Dict[str, Any]: self.load_checkpoint_file_called = True return torch.load(path) @@ -31,9 +31,9 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> self.save_checkpoint_called = True super().save_checkpoint(checkpoint, path) - def load_checkpoint_file(self, path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint(self, path: Union[str, Path]) -> Dict[str, Any]: self.load_checkpoint_file_called = True - return super().load_checkpoint_file(path) + return super().load_checkpoint(path) @pytest.mark.parametrize("checkpoint_plugin", [CustomTorchCheckpointIOPlugin(), CustomCheckpointPlugin()]) From 442127637b8f8a0896aea731d442390378e31352 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 10 Aug 2021 16:01:45 +0100 Subject: [PATCH 17/37] Update 
signature to be more general --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 6 +++--- pytorch_lightning/plugins/checkpoint/torch.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index 851173f46c917..e55d8ace9078c 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, Optional, Union class CheckpointIOPlugin(ABC): @abstractmethod def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Mapping] = None + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -17,7 +17,7 @@ def save_checkpoint( """ @abstractmethod - def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Mapping] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 31ed28625786b..a580746b4d0f1 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Dict, Mapping, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import pytorch_lightning as pl from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin @@ -10,7 +10,7 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Mapping] = None + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: try: # write the checkpoint dictionary on the file From d84cce1d9a98ba15474cfa0b4cb766b0c0055795 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 11 Aug 2021 10:39:36 +0100 Subject: [PATCH 18/37] Address feedback, add test for support check --- .../plugins/checkpoint/checkpoint.py | 11 ++++-- ...plugin.py => test_checkpoint_io_plugin.py} | 36 +++++++++++++++---- 2 files changed, 38 insertions(+), 9 deletions(-) rename tests/plugins/{test_checkpoint_plugin.py => test_checkpoint_io_plugin.py} (53%) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index e55d8ace9078c..32e91d1cd5c99 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -1,12 +1,15 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, TypeVar, Union + +TSaveStorageOptions = TypeVar("S") +TLoadStorageOptions = TypeVar("T") class CheckpointIOPlugin(ABC): @abstractmethod def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[TSaveStorageOptions] = None ) -> 
None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -17,7 +20,9 @@ def save_checkpoint( """ @abstractmethod - def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint( + self, path: Union[str, Path], storage_options: Optional[TLoadStorageOptions] = None + ) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: diff --git a/tests/plugins/test_checkpoint_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py similarity index 53% rename from tests/plugins/test_checkpoint_plugin.py rename to tests/plugins/test_checkpoint_io_plugin.py index f0acb1b42f06c..fe1e4f7ab3dfe 100644 --- a/tests/plugins/test_checkpoint_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -1,12 +1,21 @@ from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Callable, Dict, Optional, Union +from unittest import mock import pytest import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.plugins import CheckpointIOPlugin, SingleDevicePlugin, TorchCheckpointIOPlugin +from pytorch_lightning.plugins import ( + CheckpointIOPlugin, + DeepSpeedPlugin, + SingleDevicePlugin, + TorchCheckpointIOPlugin, + TPUSpawnPlugin, +) +from pytorch_lightning.plugins.checkpoint.checkpoint import TLoadStorageOptions, TSaveStorageOptions +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel @@ -14,11 +23,15 @@ class CustomCheckpointPlugin(CheckpointIOPlugin): save_checkpoint_called: bool = False load_checkpoint_file_called: bool = False - def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[TSaveStorageOptions] = None + ) -> None: self.save_checkpoint_called = True torch.save(checkpoint, path) - def load_checkpoint(self, path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint( + self, path: Union[str, Path], storage_options: Optional[TLoadStorageOptions] = None + ) -> Dict[str, Any]: self.load_checkpoint_file_called = True return torch.load(path) @@ -27,11 +40,15 @@ class CustomTorchCheckpointIOPlugin(TorchCheckpointIOPlugin): save_checkpoint_called: bool = False load_checkpoint_file_called: bool = False - def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None: + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None + ) -> None: self.save_checkpoint_called = True super().save_checkpoint(checkpoint, path) - def load_checkpoint(self, path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint( + self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage + ) -> Dict[str, Any]: self.load_checkpoint_file_called = True return super().load_checkpoint(path) @@ -56,3 +73,10 @@ def test_checkpoint_plugin_called(tmpdir, checkpoint_plugin): assert checkpoint_plugin.save_checkpoint_called trainer.test(model, ckpt_path=ck.last_model_path) assert checkpoint_plugin.load_checkpoint_file_called + + +@mock.patch("pytorch_lightning.utilities._DEEPSPEED_AVAILABLE", return_value=True) +@pytest.mark.parametrize("plugin", [DeepSpeedPlugin(), TPUSpawnPlugin()]) +def test_no_checkpoint_io_plugin_support(mock_deepspeed, 
plugin): + with pytest.raises(MisconfigurationException, match="currently does not support custom checkpoint plugins"): + plugin.checkpoint_plugin = CustomTorchCheckpointIOPlugin() From b7f37eec8c72559373da1d481641ecfc22d61141 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 11 Aug 2021 10:42:48 +0100 Subject: [PATCH 19/37] Add CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77163b62457b6..f499f8c04f3c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) +- Added `CheckpointIOPlugin` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) From 49086cc9ea12edebd2539586f7481b4bcaa8072e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 11 Aug 2021 11:26:37 +0100 Subject: [PATCH 20/37] fix tests --- tests/plugins/test_checkpoint_io_plugin.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index fe1e4f7ab3dfe..097ab32dfcff6 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -1,6 +1,5 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Union -from unittest import mock import pytest import torch @@ -17,6 +16,7 @@ from pytorch_lightning.plugins.checkpoint.checkpoint import TLoadStorageOptions, TSaveStorageOptions from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf class CustomCheckpointPlugin(CheckpointIOPlugin): @@ -75,8 +75,7 @@ def test_checkpoint_plugin_called(tmpdir, checkpoint_plugin): assert checkpoint_plugin.load_checkpoint_file_called -@mock.patch("pytorch_lightning.utilities._DEEPSPEED_AVAILABLE", return_value=True) -@pytest.mark.parametrize("plugin", [DeepSpeedPlugin(), TPUSpawnPlugin()]) -def test_no_checkpoint_io_plugin_support(mock_deepspeed, plugin): +@pytest.mark.parametrize("plugin_cls", [pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), TPUSpawnPlugin]) +def test_no_checkpoint_io_plugin_support(plugin_cls): with pytest.raises(MisconfigurationException, match="currently does not support custom checkpoint plugins"): - plugin.checkpoint_plugin = CustomTorchCheckpointIOPlugin() + plugin_cls().checkpoint_plugin = CustomTorchCheckpointIOPlugin() From 936f65af52c52dcafa314a3c8134f5ad4702a858 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 11 Aug 2021 12:18:52 +0100 Subject: [PATCH 21/37] change name --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index 32e91d1cd5c99..b27bb84f86927 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -2,8 +2,8 @@ from pathlib import Path from typing import 
Any, Dict, Optional, TypeVar, Union -TSaveStorageOptions = TypeVar("S") -TLoadStorageOptions = TypeVar("T") +TSaveStorageOptions = TypeVar("TSaveStorageOptions") +TLoadStorageOptions = TypeVar("TLoadStorageOptions") class CheckpointIOPlugin(ABC): From 049a67620bb4f98c2d716a2acc49a378553c09e9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 11 Aug 2021 14:46:22 +0100 Subject: [PATCH 22/37] Fix mypy --- pytorch_lightning/plugins/checkpoint/torch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index a580746b4d0f1..cd5cd6a0a0a33 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -1,8 +1,8 @@ from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Dict, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin, TLoadStorageOptions, TSaveStorageOptions from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -10,7 +10,7 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[TSaveStorageOptions] = None ) -> None: try: # write the checkpoint dictionary on the file @@ -22,6 +22,6 @@ def save_checkpoint( atomic_save(checkpoint, path) def load_checkpoint( - self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage + self, path: Union[str, Path], map_location: Optional[TLoadStorageOptions] = lambda storage, loc: storage ) -> Dict[str, Any]: return pl_load(path, map_location=map_location) From 1ff091256e7d29f827dc3f3758c8fce38f895c09 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Aug 2021 12:01:39 +0100 Subject: [PATCH 23/37] Reviews --- pytorch_lightning/plugins/checkpoint/torch.py | 8 ++++---- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- pytorch_lightning/plugins/training_type/deepspeed.py | 4 +++- pytorch_lightning/plugins/training_type/fully_sharded.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index cd5cd6a0a0a33..a580746b4d0f1 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -1,8 +1,8 @@ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin, TLoadStorageOptions, TSaveStorageOptions +from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -10,7 +10,7 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[TSaveStorageOptions] = None + self, checkpoint: Dict[str, 
Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: try: # write the checkpoint dictionary on the file @@ -22,6 +22,6 @@ def save_checkpoint( atomic_save(checkpoint, path) def load_checkpoint( - self, path: Union[str, Path], map_location: Optional[TLoadStorageOptions] = lambda storage, loc: storage + self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage ) -> Dict[str, Any]: return pl_load(path, map_location=map_location) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index ae1ea6a25716e..61c442e3d881e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -63,7 +63,7 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, - cluster_environment: ClusterEnvironment = None, + cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_plugin: Optional[CheckpointIOPlugin] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 1af959fdd3287..853f3ff8af7e0 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -687,7 +687,9 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: warning_cache.warn( "When saving the DeepSpeed Stage 3 checkpoint, " "each worker will save a shard of the checkpoint within a directory. " - "If a single file is required after training, see for instructions." + "If a single file is required after training, " + "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#" + "deepspeed-zero-stage-3-single-file for instructions." 
) # Use deepspeed's internal checkpointing function to handle partitioned weights across processes # dump states as a checkpoint dictionary object diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 10ac61ff5decf..117bda7b2126d 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -40,7 +40,7 @@ def __init__( min_num_params: int = 1e8, state_dict_to_cpu: bool = True, parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: ClusterEnvironment = None, + cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_plugin: Optional[CheckpointIOPlugin] = None, ): """ From 1841d3b7e9c2db70ea113aa6c944e274350e4ee8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Aug 2021 15:23:38 +0100 Subject: [PATCH 24/37] Add ability to pass checkpoint plugin through the trainer --- .../trainer/connectors/accelerator_connector.py | 16 +++++++++++++++- tests/plugins/test_checkpoint_io_plugin.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 0f8d69706a147..9da9e31cbc6c4 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -26,6 +26,7 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.plugins import ( ApexMixedPrecisionPlugin, + CheckpointIOPlugin, DataParallelPlugin, DDP2Plugin, DDPFullyShardedPlugin, @@ -134,6 +135,7 @@ def __init__( self._precision_plugin: Optional[PrecisionPlugin] = None self._training_type_plugin: Optional[TrainingTypePlugin] = None self._cluster_environment: Optional[ClusterEnvironment] = None + self._checkpoint_plugin: Optional[CheckpointIOPlugin] = None plugins = plugins if plugins is not None else [] @@ -274,6 +276,7 @@ def _set_devices_if_none(self) -> None: def handle_given_plugins(self) -> None: training_type = None + checkpoint = None precision = None cluster_environment = None @@ -310,7 +313,14 @@ def handle_given_plugins(self) -> None: "You can only specify one precision and one training type plugin." f" Found more than 1 precision plugin: {type(plug).__name__}" ) - + elif isinstance(plug, CheckpointIOPlugin): + if checkpoint is None: + checkpoint = plug + else: + raise MisconfigurationException( + "You can only specify one checkpoint plugin and one training type plugin." 
+ f" Found more than 1 checkpoint plugin: {type(plug).__name__}" + ) elif isinstance(plug, ClusterEnvironment): if cluster_environment is None: cluster_environment = plug @@ -325,6 +335,7 @@ def handle_given_plugins(self) -> None: self._training_type_plugin = training_type self._precision_plugin = precision + self._checkpoint_plugin = checkpoint self._cluster_environment = cluster_environment or self.select_cluster_environment() @property @@ -341,6 +352,9 @@ def training_type_plugin(self) -> TrainingTypePlugin: if self._training_type_plugin is None: self._training_type_plugin = self.select_training_type_plugin() self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) + # attach checkpoint plugin to the training type plugin + if self._checkpoint_plugin is not None: + self._training_type_plugin.checkpoint_plugin = self._checkpoint_plugin self._training_type_plugin_resolved = True return self._training_type_plugin diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 097ab32dfcff6..e716eb289211d 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -74,6 +74,23 @@ def test_checkpoint_plugin_called(tmpdir, checkpoint_plugin): trainer.test(model, ckpt_path=ck.last_model_path) assert checkpoint_plugin.load_checkpoint_file_called + checkpoint_plugin.save_checkpoint_called = False + checkpoint_plugin.load_checkpoint_file_called = False + ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) + + model = BoringModel() + device = torch.device("cpu") + trainer = Trainer( + default_root_dir=tmpdir, + plugins=[SingleDevicePlugin(device), checkpoint_plugin], + callbacks=ck, + max_epochs=1, + ) + trainer.fit(model) + assert checkpoint_plugin.save_checkpoint_called + trainer.test(model, ckpt_path=ck.last_model_path) + assert checkpoint_plugin.load_checkpoint_file_called + @pytest.mark.parametrize("plugin_cls", [pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), TPUSpawnPlugin]) def test_no_checkpoint_io_plugin_support(plugin_cls): From b909dfe3baaa1c1f482e242c318613aedc6b2447 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Aug 2021 15:59:48 +0100 Subject: [PATCH 25/37] Add constraints --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index b27bb84f86927..158d8fcfbfe6c 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Mapping, Optional, TypeVar, Union -TSaveStorageOptions = TypeVar("TSaveStorageOptions") -TLoadStorageOptions = TypeVar("TLoadStorageOptions") +TSaveStorageOptions = TypeVar("TSaveStorageOptions", Mapping, Callable) +TLoadStorageOptions = TypeVar("TLoadStorageOptions", Mapping, Callable) class CheckpointIOPlugin(ABC): From 50b11b5c064b07a9854af82803ea1729b2b5031a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 12 Aug 2021 17:07:46 +0100 Subject: [PATCH 26/37] Match signature to see if mypy works --- pytorch_lightning/plugins/checkpoint/torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 
a580746b4d0f1..4452509a03993 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -22,6 +22,6 @@ def save_checkpoint( atomic_save(checkpoint, path) def load_checkpoint( - self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage + self, path: Union[str, Path], storage_options: Optional[Callable] = lambda storage, loc: storage ) -> Dict[str, Any]: - return pl_load(path, map_location=map_location) + return pl_load(path, map_location=storage_options) From 9e16e34631334da2604339bc012c829b22e58d63 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 09:54:24 +0100 Subject: [PATCH 27/37] Address review points --- .../plugins/checkpoint/__init__.py | 13 ++++ .../plugins/checkpoint/checkpoint.py | 29 +++++++++ pytorch_lightning/plugins/checkpoint/torch.py | 13 ++++ tests/plugins/test_checkpoint_io_plugin.py | 61 ++++++++----------- 4 files changed, 82 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/__init__.py b/pytorch_lightning/plugins/checkpoint/__init__.py index cf936291c7a26..c4ec8b8c89573 100644 --- a/pytorch_lightning/plugins/checkpoint/__init__.py +++ b/pytorch_lightning/plugins/checkpoint/__init__.py @@ -1,2 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin # noqa: F401 from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index 158d8fcfbfe6c..e204934c0f3c2 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Callable, Dict, Mapping, Optional, TypeVar, Union @@ -7,6 +20,22 @@ class CheckpointIOPlugin(ABC): + """ + Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``. + + Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIOPlugin`` but may + require particular handling depending the plugin. + + In addition, you can pass a custom ``CheckpointIOPlugin`` by extending this class and passing it + to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIOPlugin()])``. 
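To make the passage above concrete, here is a minimal sketch of such a custom plugin as it could look at this point in the series. This is an illustration only: the class name is hypothetical, plain ``torch.save``/``torch.load`` stand in for whatever storage backend a user might plug in, and the ``save_checkpoint``/``load_checkpoint`` signatures are assumed to be the ones introduced in these patches.

.. code-block:: python

    from pathlib import Path
    from typing import Any, Dict, List, Optional, Union

    import torch

    from pytorch_lightning.plugins import CheckpointIOPlugin


    class TrackingCheckpointIOPlugin(CheckpointIOPlugin):
        """Hypothetical plugin: delegates to torch but records every path it has written."""

        def __init__(self) -> None:
            self.saved_paths: List[Path] = []

        def save_checkpoint(
            self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
        ) -> None:
            # any custom serialization or remote upload could happen here instead of torch.save
            torch.save(checkpoint, path)
            self.saved_paths.append(Path(path))

        def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
            # load onto CPU so resuming does not depend on the device used during training
            return torch.load(path, map_location="cpu")

Such a plugin would then be passed either directly to the ``Trainer`` via ``plugins=[TrackingCheckpointIOPlugin()]`` or to a training type plugin that accepts a ``checkpoint_plugin`` argument.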
+ + .. note:: + + For some plugins it is not possible to use a custom checkpoint plugin as checkpointing logic is not + modifiable. + + """ + @abstractmethod def save_checkpoint( self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[TSaveStorageOptions] = None diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 4452509a03993..f0758bdc0365a 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from pathlib import Path from typing import Any, Callable, Dict, Optional, Union diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index e716eb289211d..b4d0a5f65804a 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -1,18 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Dict, Optional, Union +from unittest.mock import MagicMock import pytest import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.plugins import ( - CheckpointIOPlugin, - DeepSpeedPlugin, - SingleDevicePlugin, - TorchCheckpointIOPlugin, - TPUSpawnPlugin, -) +from pytorch_lightning.plugins import CheckpointIOPlugin, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin from pytorch_lightning.plugins.checkpoint.checkpoint import TLoadStorageOptions, TSaveStorageOptions from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel @@ -36,28 +44,12 @@ def load_checkpoint( return torch.load(path) -class CustomTorchCheckpointIOPlugin(TorchCheckpointIOPlugin): - save_checkpoint_called: bool = False - load_checkpoint_file_called: bool = False - - def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None - ) -> None: - self.save_checkpoint_called = True - super().save_checkpoint(checkpoint, path) - - def load_checkpoint( - self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage - ) -> Dict[str, Any]: - self.load_checkpoint_file_called = True - return super().load_checkpoint(path) - - -@pytest.mark.parametrize("checkpoint_plugin", [CustomTorchCheckpointIOPlugin(), CustomCheckpointPlugin()]) -def test_checkpoint_plugin_called(tmpdir, checkpoint_plugin): +def test_checkpoint_plugin_called(tmpdir): """ Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading. """ + checkpoint_plugin = CustomCheckpointPlugin() + checkpoint_plugin = MagicMock(wraps=checkpoint_plugin, spec=CustomCheckpointPlugin) ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) @@ -70,12 +62,11 @@ def test_checkpoint_plugin_called(tmpdir, checkpoint_plugin): max_epochs=1, ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint_called + assert checkpoint_plugin.save_checkpoint.call_count == 3 trainer.test(model, ckpt_path=ck.last_model_path) - assert checkpoint_plugin.load_checkpoint_file_called + checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt") - checkpoint_plugin.save_checkpoint_called = False - checkpoint_plugin.load_checkpoint_file_called = False + checkpoint_plugin.reset_mock() ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) model = BoringModel() @@ -87,12 +78,14 @@ def test_checkpoint_plugin_called(tmpdir, checkpoint_plugin): max_epochs=1, ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint_called + assert checkpoint_plugin.save_checkpoint.call_count == 3 + trainer.test(model, ckpt_path=ck.last_model_path) - assert checkpoint_plugin.load_checkpoint_file_called + checkpoint_plugin.load_checkpoint.assert_called_once() + checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt") @pytest.mark.parametrize("plugin_cls", [pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), TPUSpawnPlugin]) def test_no_checkpoint_io_plugin_support(plugin_cls): with pytest.raises(MisconfigurationException, match="currently does not support custom checkpoint plugins"): - plugin_cls().checkpoint_plugin = CustomTorchCheckpointIOPlugin() + plugin_cls().checkpoint_plugin = CustomCheckpointPlugin() From 642e6fa479120ae5ff1420c3bdcb6d078541fab7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 
2021 09:58:01 +0100 Subject: [PATCH 28/37] Revert changes to typing --- pytorch_lightning/plugins/checkpoint/checkpoint.py | 11 +++-------- pytorch_lightning/plugins/checkpoint/torch.py | 2 ++ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index e204934c0f3c2..e05b2fd6d9862 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -13,10 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, Dict, Mapping, Optional, TypeVar, Union - -TSaveStorageOptions = TypeVar("TSaveStorageOptions", Mapping, Callable) -TLoadStorageOptions = TypeVar("TLoadStorageOptions", Mapping, Callable) +from typing import Any, Dict, Optional, Union class CheckpointIOPlugin(ABC): @@ -38,7 +35,7 @@ class CheckpointIOPlugin(ABC): @abstractmethod def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[TSaveStorageOptions] = None + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -49,9 +46,7 @@ def save_checkpoint( """ @abstractmethod - def load_checkpoint( - self, path: Union[str, Path], storage_options: Optional[TLoadStorageOptions] = None - ) -> Dict[str, Any]: + def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index f0758bdc0365a..54008bf878b9d 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -29,6 +29,8 @@ def save_checkpoint( # write the checkpoint dictionary on the file atomic_save(checkpoint, path) except AttributeError as err: + # todo (sean): is this try catch necessary still? + # https://github.com/PyTorchLightning/pytorch-lightning/pull/431 key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY checkpoint.pop(key, None) rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") From 5c9e973c7e381af47e77005b3315449cc60f4a4f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 10:54:25 +0100 Subject: [PATCH 29/37] Add docs/doc strings and API --- docs/source/advanced/checkpoint_io.rst | 52 +++++++++++++++++++ docs/source/api_references.rst | 12 +++++ docs/source/index.rst | 1 + .../plugins/checkpoint/checkpoint.py | 1 + pytorch_lightning/plugins/checkpoint/torch.py | 19 ++++++- tests/plugins/test_checkpoint_io_plugin.py | 7 +-- 6 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 docs/source/advanced/checkpoint_io.rst diff --git a/docs/source/advanced/checkpoint_io.rst b/docs/source/advanced/checkpoint_io.rst new file mode 100644 index 0000000000000..9975878eff780 --- /dev/null +++ b/docs/source/advanced/checkpoint_io.rst @@ -0,0 +1,52 @@ +Custom Checkpointing IO +======================= + +.. warning:: The Checkpoint IO API is experimental and subject to change. + +Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIOPlugin``. This encapsulates the save/load logic +that is managed by the ``TrainingTypePlugin``. 
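For orientation, a brief sketch of the default wiring these patches set up (assuming the ``TorchCheckpointIOPlugin`` fallback added to ``TrainingTypePlugin`` in this series; treat this as an illustration rather than finalized API):

.. code-block:: python

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import SingleDevicePlugin, TorchCheckpointIOPlugin

    # when no checkpoint IO plugin is given, the training type plugin
    # falls back to the Torch-based implementation ...
    trainer = Trainer(plugins=SingleDevicePlugin(torch.device("cpu")))

    # ... which is equivalent to passing it explicitly
    trainer = Trainer(
        plugins=SingleDevicePlugin(torch.device("cpu"), checkpoint_plugin=TorchCheckpointIOPlugin())
    )

The custom-plugin example in the documentation below then swaps this default for a user-defined implementation.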
+ +``CheckpointIOPlugin`` can be extended to include your custom save/load functionality to and from a path, with the object being passed to either a ``Trainer`` object or a ``TrainingTypePlugin`` as shown below. + +.. code-block:: python + + from pathlib import Path + from typing import Any, Dict, Optional, Union + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import ModelCheckpoint + from pytorch_lightning.plugins import CheckpointIOPlugin, SingleDevicePlugin + + + class CustomCheckpointPlugin(CheckpointIOPlugin): + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None + ) -> None: + ... + + def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: + ... + + + checkpoint_plugin = CustomCheckpointPlugin() + + # Pass into the Trainer object + model = MyModel() + trainer = Trainer( + plugins=[checkpoint_plugin], + callbacks=ModelCheckpoint(save_last=True), + ) + trainer.fit(model) + + # pass into TrainingTypePlugin + model = MyModel() + device = torch.device("cpu") + trainer = Trainer( + plugins=SingleDevicePlugin(device, checkpoint_plugin=checkpoint_plugin), + callbacks=ModelCheckpoint(save_last=True), + ) + trainer.fit(model) + +.. note:: + + Some ``TrainingTypePlugins`` do not support custom ``CheckpointIOPlugin`` as checkpointing logic is not modifiable. diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index 919aece8a72d3..632129ba5c981 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -127,6 +127,18 @@ Cluster Environments KubeflowEnvironment SLURMEnvironment +Checkpoint IO Plugins +^^^^^^^^^^^^^^^^^^^^^ + +.. currentmodule:: pytorch_lightning.plugins.checkpoint + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + CheckpointIOPlugin + TorchCheckpointIOPlugin Profiler API ------------ diff --git a/docs/source/index.rst b/docs/source/index.rst index 5e8e4352b57d8..105a1dfff54bf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -55,6 +55,7 @@ PyTorch Lightning Documentation advanced/multi_gpu advanced/advanced_gpu common/weights_loading + advanced/checkpoint_io common/optimizers advanced/profiler advanced/sequences diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/checkpoint/checkpoint.py index e05b2fd6d9862..cecb4b3c947ce 100644 --- a/pytorch_lightning/plugins/checkpoint/checkpoint.py +++ b/pytorch_lightning/plugins/checkpoint/checkpoint.py @@ -49,6 +49,7 @@ def save_checkpoint( def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. + Args: path: Path to checkpoint storage_options: Optional parameters when loading the model/training states. diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/checkpoint/torch.py index 54008bf878b9d..b5131d90a5357 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/checkpoint/torch.py @@ -22,6 +22,11 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): + """ + CheckpointIOPlugin that utilizes ``torch.save``/``torch.load`` to save and load checkpoints, + common for most use cases.
+ """ + def save_checkpoint( self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: @@ -37,6 +42,16 @@ def save_checkpoint( atomic_save(checkpoint, path) def load_checkpoint( - self, path: Union[str, Path], storage_options: Optional[Callable] = lambda storage, loc: storage + self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage ) -> Dict[str, Any]: - return pl_load(path, map_location=storage_options) + """ + Loads checkpoint using torch.load, with additional handling for fsspec remote loading of files. + + Args: + path: Path to checkpoint + map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage + locations. + + Returns: The loaded checkpoint. + """ + return pl_load(path, map_location=map_location) diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index b4d0a5f65804a..9e289b857319a 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -21,7 +21,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.plugins import CheckpointIOPlugin, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin -from pytorch_lightning.plugins.checkpoint.checkpoint import TLoadStorageOptions, TSaveStorageOptions from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -32,14 +31,12 @@ class CustomCheckpointPlugin(CheckpointIOPlugin): load_checkpoint_file_called: bool = False def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[TSaveStorageOptions] = None + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: self.save_checkpoint_called = True torch.save(checkpoint, path) - def load_checkpoint( - self, path: Union[str, Path], storage_options: Optional[TLoadStorageOptions] = None - ) -> Dict[str, Any]: + def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: self.load_checkpoint_file_called = True return torch.load(path) From 6361c8738eabed522ef4f922036503e5a939e812 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 12:20:39 +0100 Subject: [PATCH 30/37] Address feedback --- docs/source/api_references.rst | 2 +- pytorch_lightning/plugins/__init__.py | 4 ++-- pytorch_lightning/plugins/{checkpoint => io}/__init__.py | 4 ++-- .../{checkpoint/checkpoint.py => io/checkpoint_plugin.py} | 0 .../plugins/{checkpoint/torch.py => io/torch_plugin.py} | 2 +- pytorch_lightning/plugins/training_type/ddp.py | 2 +- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- pytorch_lightning/plugins/training_type/dp.py | 2 +- pytorch_lightning/plugins/training_type/fully_sharded.py | 2 +- pytorch_lightning/plugins/training_type/horovod.py | 2 +- pytorch_lightning/plugins/training_type/ipu.py | 2 +- pytorch_lightning/plugins/training_type/parallel.py | 2 +- pytorch_lightning/plugins/training_type/single_device.py | 2 +- pytorch_lightning/plugins/training_type/single_tpu.py | 2 +- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- .../plugins/training_type/training_type_plugin.py | 2 +- tests/accelerators/test_cpu.py | 2 +- 18 files changed, 19 insertions(+), 19 deletions(-) rename 
pytorch_lightning/plugins/{checkpoint => io}/__init__.py (75%) rename pytorch_lightning/plugins/{checkpoint/checkpoint.py => io/checkpoint_plugin.py} (100%) rename pytorch_lightning/plugins/{checkpoint/torch.py => io/torch_plugin.py} (96%) diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index 632129ba5c981..91dda78bf3418 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -130,7 +130,7 @@ Cluster Environments Checkpoint IO Plugins ^^^^^^^^^^^^^^^^^^^^^ -.. currentmodule:: pytorch_lightning.plugins.checkpoint +.. currentmodule:: pytorch_lightning.plugins.io .. autosummary:: :toctree: api diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index f8e8dedaccae5..546ac66e375c1 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,6 +1,6 @@ from pytorch_lightning.plugins.base_plugin import Plugin -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin -from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIOPlugin from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401 call_training_type_register_plugins, TrainingTypePluginsRegistry, diff --git a/pytorch_lightning/plugins/checkpoint/__init__.py b/pytorch_lightning/plugins/io/__init__.py similarity index 75% rename from pytorch_lightning/plugins/checkpoint/__init__.py rename to pytorch_lightning/plugins/io/__init__.py index c4ec8b8c89573..e111d2266cd86 100644 --- a/pytorch_lightning/plugins/checkpoint/__init__.py +++ b/pytorch_lightning/plugins/io/__init__.py @@ -11,5 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin # noqa: F401 -from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin # noqa: F401 +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin # noqa: F401 +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIOPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/checkpoint/checkpoint.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py similarity index 100% rename from pytorch_lightning/plugins/checkpoint/checkpoint.py rename to pytorch_lightning/plugins/io/checkpoint_plugin.py diff --git a/pytorch_lightning/plugins/checkpoint/torch.py b/pytorch_lightning/plugins/io/torch_plugin.py similarity index 96% rename from pytorch_lightning/plugins/checkpoint/torch.py rename to pytorch_lightning/plugins/io/torch_plugin.py index b5131d90a5357..01ead33f7cc54 100644 --- a/pytorch_lightning/plugins/checkpoint/torch.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index d4daf3dc6f9f1..ec868525c7119 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -31,8 +31,8 @@ from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import ( _HYDRA_AVAILABLE, diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 61c442e3d881e..b2c67454ecb7d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -24,8 +24,8 @@ from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 853f3ff8af7e0..58ec9af47366b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -26,8 +26,8 @@ 
import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 2fe80a3ee9c3f..299aee53ec98f 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -17,7 +17,7 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 117bda7b2126d..7acf306171c83 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -16,8 +16,8 @@ import torch -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 3657487fb237d..120a07ea27bd4 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 35283e0680b62..e29724a810b7e 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -22,8 +22,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import 
ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import CombinedLoader diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index df2f1dffde0cd..524c25d0c3124 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -21,8 +21,8 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index da70afb0efc6f..356760c6c97fc 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -15,7 +15,7 @@ import torch -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 2e9987bb542e9..9bec8f2390516 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -15,7 +15,7 @@ from typing import Any, Dict from pytorch_lightning.core.decorators import parameter_validation -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index c72d8069cc72a..c863f265d3b69 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -25,7 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.decorators import parameter_validation from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import TrainerFn diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py 
b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e1c4fe4b63281..a095087be879d 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -26,7 +26,7 @@ from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIOPlugin from pytorch_lightning.plugins.base_plugin import Plugin -from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 6d92b882bc368..d195785a98ce1 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin -from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointIOPlugin +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIOPlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException From c921dbb19091282765cac27f7b99ded5040c91de Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Fri, 13 Aug 2021 12:25:40 +0100 Subject: [PATCH 31/37] Update pytorch_lightning/plugins/training_type/training_type_plugin.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index a095087be879d..915cd113c724a 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -41,7 +41,7 @@ def __init__(self, checkpoint_plugin: Optional[CheckpointIOPlugin] = None) -> No self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_plugin = checkpoint_plugin if checkpoint_plugin is not None else TorchCheckpointIOPlugin() - self._checkpoint_plugin: CheckpointIOPlugin = checkpoint_plugin + self._checkpoint_plugin = checkpoint_plugin self._call_configure_sharded_model_hook = True @property From fd82276b96574f087b07e641e6c299709fd6903d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 12:26:04 +0100 Subject: [PATCH 32/37] Address reviews --- pytorch_lightning/plugins/io/checkpoint_plugin.py | 4 ++-- pytorch_lightning/plugins/io/torch_plugin.py | 4 ++-- .../trainer/connectors/accelerator_connector.py | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/io/checkpoint_plugin.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py index cecb4b3c947ce..b866bdc032072 100644 --- a/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -21,14 +21,14 @@ class CheckpointIOPlugin(ABC): Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``. 
Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIOPlugin`` but may - require particular handling depending the plugin. + require particular handling depending on the plugin. In addition, you can pass a custom ``CheckpointIOPlugin`` by extending this class and passing it to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIOPlugin()])``. .. note:: - For some plugins it is not possible to use a custom checkpoint plugin as checkpointing logic is not + For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not modifiable. """ diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index 01ead33f7cc54..ce34ef0093c95 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -23,8 +23,8 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): """ - CheckpointIOPlugin that utilizes ``torch.save``/``torch.load`` to save and load checkpoints, - common for most use cases. + CheckpointIOPlugin that utilizes :func:`torch.save` and :func:`torch.load` + to save and load checkpoints respectively, common for most use cases. """ def save_checkpoint( diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 9da9e31cbc6c4..0a65e4bcb1a56 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -302,24 +302,24 @@ def handle_given_plugins(self) -> None: else: raise MisconfigurationException( - "You can only specify one precision and one training type plugin." - f" Found more than 1 training type plugin: {type(plug).__name__}" + "You can only specify one training type plugin." + f" Available: {type(training_type).__name__}, given: {type(plug).__name__}" ) elif isinstance(plug, PrecisionPlugin): if precision is None: precision = plug else: raise MisconfigurationException( - "You can only specify one precision and one training type plugin." - f" Found more than 1 precision plugin: {type(plug).__name__}" + "You can only specify one precision plugin." + f" Available: {type(precision).__name__}, given: {type(plug).__name__}" ) elif isinstance(plug, CheckpointIOPlugin): if checkpoint is None: checkpoint = plug else: raise MisconfigurationException( - "You can only specify one checkpoint plugin and one training type plugin." - f" Found more than 1 checkpoint plugin: {type(plug).__name__}" + "You can only specify one checkpoint plugin." 
+ f" Available: {type(checkpoint).__name__}, given: {type(plug).__name__}" ) elif isinstance(plug, ClusterEnvironment): if cluster_environment is None: From 2fc3558b319d53aadc7e40c824083ef90352b4da Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 12:28:27 +0100 Subject: [PATCH 33/37] Update typing --- pytorch_lightning/plugins/io/checkpoint_plugin.py | 11 +++++------ pytorch_lightning/plugins/io/torch_plugin.py | 10 ++++------ pytorch_lightning/utilities/types.py | 2 ++ tests/plugins/test_checkpoint_io_plugin.py | 10 ++++------ 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/plugins/io/checkpoint_plugin.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py index b866bdc032072..876921bcc9066 100644 --- a/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional + +from pytorch_lightning.utilities.types import _PATH class CheckpointIOPlugin(ABC): @@ -34,9 +35,7 @@ class CheckpointIOPlugin(ABC): """ @abstractmethod - def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -46,7 +45,7 @@ def save_checkpoint( """ @abstractmethod - def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: """ Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index ce34ef0093c95..7f6ed15fc1815 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.types import _PATH class TorchCheckpointIOPlugin(CheckpointIOPlugin): @@ -27,9 +27,7 @@ class TorchCheckpointIOPlugin(CheckpointIOPlugin): to save and load checkpoints respectively, common for most use cases. 
""" - def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: try: # write the checkpoint dictionary on the file atomic_save(checkpoint, path) @@ -42,7 +40,7 @@ def save_checkpoint( atomic_save(checkpoint, path) def load_checkpoint( - self, path: Union[str, Path], map_location: Optional[Callable] = lambda storage, loc: storage + self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage ) -> Dict[str, Any]: """ Loads checkpoint using torch.load, with additional handling for fsspec remote loading of files. diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 774f44ceaeeec..69cac5edf784e 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,6 +17,7 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`) """ from numbers import Number +from pathlib import Path from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union import torch @@ -31,6 +32,7 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] +_PATH = Union[str, Path] TRAIN_DATALOADERS = Union[ DataLoader, Sequence[DataLoader], diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 9e289b857319a..a605d07e61d20 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional from unittest.mock import MagicMock import pytest @@ -22,6 +21,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.plugins import CheckpointIOPlugin, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _PATH from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -30,13 +30,11 @@ class CustomCheckpointPlugin(CheckpointIOPlugin): save_checkpoint_called: bool = False load_checkpoint_file_called: bool = False - def save_checkpoint( - self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: self.save_checkpoint_called = True torch.save(checkpoint, path) - def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: self.load_checkpoint_file_called = True return torch.load(path) From 21783f6abf1365149e2f6e4b555ceecc6d7efc07 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 14:47:03 +0100 Subject: [PATCH 34/37] Refactor name --- CHANGELOG.md | 2 +- docs/source/advanced/checkpoint_io.rst | 10 +++++----- docs/source/api_references.rst | 4 ++-- pytorch_lightning/plugins/__init__.py | 8 ++++---- pytorch_lightning/plugins/io/__init__.py | 4 ++-- pytorch_lightning/plugins/io/checkpoint_plugin.py | 8 ++++---- pytorch_lightning/plugins/io/torch_plugin.py | 6 +++--- pytorch_lightning/plugins/training_type/ddp.py | 4 ++-- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- pytorch_lightning/plugins/training_type/deepspeed.py | 6 +++--- pytorch_lightning/plugins/training_type/dp.py | 4 ++-- .../plugins/training_type/fully_sharded.py | 4 ++-- pytorch_lightning/plugins/training_type/horovod.py | 4 ++-- pytorch_lightning/plugins/training_type/ipu.py | 4 ++-- pytorch_lightning/plugins/training_type/parallel.py | 4 ++-- .../plugins/training_type/single_device.py | 4 ++-- .../plugins/training_type/single_tpu.py | 6 +++--- pytorch_lightning/plugins/training_type/tpu_spawn.py | 6 +++--- .../plugins/training_type/training_type_plugin.py | 12 ++++++------ .../trainer/connectors/accelerator_connector.py | 6 +++--- tests/accelerators/test_cpu.py | 4 ++-- tests/plugins/test_checkpoint_io_plugin.py | 10 +++++----- 22 files changed, 62 insertions(+), 62 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f499f8c04f3c8..2949cd8648f4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,7 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) -- Added `CheckpointIOPlugin` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) +- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) ### Changed diff --git a/docs/source/advanced/checkpoint_io.rst b/docs/source/advanced/checkpoint_io.rst index 9975878eff780..570692dbc5cc5 100644 --- a/docs/source/advanced/checkpoint_io.rst +++ b/docs/source/advanced/checkpoint_io.rst @@ -3,10 +3,10 @@ Custom Checkpointing IO .. warning:: The Checkpoint IO API is experimental and subject to change. -Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIOPlugin``. This encapsulates the save/load logic +Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic that is managed by the ``TrainingTypePlugin``. -``CheckpointIOPlugin`` can be extended to include your custom save/load functionality to and from a path, with the object being passed to either a `Trainer`` object or a``TrainingTypePlugin`` as shown below. +``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path, with the object being passed to either a `Trainer`` object or a``TrainingTypePlugin`` as shown below. .. code-block:: python @@ -15,10 +15,10 @@ that is managed by the ``TrainingTypePlugin``. from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint - from pytorch_lightning.plugins import CheckpointIOPlugin, SingleDevicePlugin + from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin - class CustomCheckpointPlugin(CheckpointIOPlugin): + class CustomCheckpointPlugin(CheckpointIO): def save_checkpoint( self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: @@ -49,4 +49,4 @@ that is managed by the ``TrainingTypePlugin``. .. note:: - Some ``TrainingTypePlugins`` do not support custom ``CheckpointIOPlugin`` as as checkpointing logic is not modifiable. + Some ``TrainingTypePlugins`` do not support custom ``CheckpointIO`` as as checkpointing logic is not modifiable. 
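A minimal end-to-end sketch of the ``CheckpointIO`` interface documented above, assuming only what this series introduces (``CheckpointIO`` exported from ``pytorch_lightning.plugins`` and the ``Trainer(plugins=[...])`` route); the class name ``ChecksummedCheckpointIO`` and the md5 sidecar file are purely illustrative and not part of this PR.

.. code-block:: python

    import hashlib
    from pathlib import Path
    from typing import Any, Dict, Optional

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import CheckpointIO


    class ChecksummedCheckpointIO(CheckpointIO):
        """Illustrative plugin: plain ``torch.save``/``torch.load`` plus an md5 sidecar (hypothetical example)."""

        def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
            # write the checkpoint dict to disk, as the default torch-based IO would
            torch.save(checkpoint, path)
            # then record a checksum next to it so the file can be verified later
            digest = hashlib.md5(Path(path).read_bytes()).hexdigest()
            Path(str(path) + ".md5").write_text(digest)

        def load_checkpoint(self, path, storage_options: Optional[Any] = None) -> Dict[str, Any]:
            # map_location="cpu" keeps the checkpoint loadable on machines without the original device
            return torch.load(path, map_location="cpu")


    # handed to the Trainer, the plugin is attached to whichever training type plugin gets selected
    trainer = Trainer(plugins=[ChecksummedCheckpointIO()])

At this point in the series the ``TrainingTypePlugin`` constructors still call the argument ``checkpoint_plugin``; a later patch renames it to ``checkpoint_io``, so the sketch sticks to the ``Trainer(plugins=...)`` form, which is stable across the whole series.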
diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index 91dda78bf3418..49b5556d7a922 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -137,8 +137,8 @@ Checkpoint IO Plugins :nosignatures: :template: classtemplate.rst - CheckpointIOPlugin - TorchCheckpointIOPlugin + CheckpointIO + TorchCheckpointIO Profiler API ------------ diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 546ac66e375c1..a69065fa74f73 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,6 +1,6 @@ from pytorch_lightning.plugins.base_plugin import Plugin -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin -from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401 call_training_type_register_plugins, TrainingTypePluginsRegistry, @@ -31,8 +31,8 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin __all__ = [ - "CheckpointIOPlugin", - "TorchCheckpointIOPlugin", + "CheckpointIO", + "TorchCheckpointIO", "ApexMixedPrecisionPlugin", "DataParallelPlugin", "DDP2Plugin", diff --git a/pytorch_lightning/plugins/io/__init__.py b/pytorch_lightning/plugins/io/__init__.py index e111d2266cd86..232f582c1a520 100644 --- a/pytorch_lightning/plugins/io/__init__.py +++ b/pytorch_lightning/plugins/io/__init__.py @@ -11,5 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin # noqa: F401 -from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIOPlugin # noqa: F401 +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401 +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401 diff --git a/pytorch_lightning/plugins/io/checkpoint_plugin.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py index 876921bcc9066..575399af48df3 100644 --- a/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -17,15 +17,15 @@ from pytorch_lightning.utilities.types import _PATH -class CheckpointIOPlugin(ABC): +class CheckpointIO(ABC): """ Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``. - Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIOPlugin`` but may + Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may require particular handling depending on the plugin. - In addition, you can pass a custom ``CheckpointIOPlugin`` by extending this class and passing it - to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIOPlugin()])``. + In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it + to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIO()])``. .. 
note:: diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index 7f6ed15fc1815..7d7200317ddf8 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -14,16 +14,16 @@ from typing import Any, Callable, Dict, Optional import pytorch_lightning as pl -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.types import _PATH -class TorchCheckpointIOPlugin(CheckpointIOPlugin): +class TorchCheckpointIO(CheckpointIO): """ - CheckpointIOPlugin that utilizes :func:`torch.save` and :func:`torch.load` + CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively, common for most use cases. """ diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index ec868525c7119..133fa5c808502 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -32,7 +32,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import ( _HYDRA_AVAILABLE, @@ -77,7 +77,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: Optional[CheckpointIO] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b2c67454ecb7d..af10d9e7254d6 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -25,7 +25,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -64,7 +64,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: Optional[CheckpointIO] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, diff --git 
a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 58ec9af47366b..43e10022f4567 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -27,7 +27,7 @@ from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn @@ -825,9 +825,9 @@ def register_plugins(cls, plugin_registry: Dict) -> None: ) @property - def checkpoint_plugin(self) -> CheckpointIOPlugin: + def checkpoint_plugin(self) -> CheckpointIO: return self._checkpoint_plugin @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: + def checkpoint_plugin(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 299aee53ec98f..6ed779a7fabc0 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -17,7 +17,7 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -33,7 +33,7 @@ class DataParallelPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]], - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: Optional[CheckpointIO] = None, ): super().__init__( parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 7acf306171c83..0fa008c42197c 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -17,7 +17,7 @@ import torch from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -41,7 +41,7 @@ def __init__( state_dict_to_cpu: bool = True, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: Optional[CheckpointIO] = None, ): """ Plugin for Fully 
Sharded Data Parallel provided by FairScale. diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 120a07ea27bd4..997b89008cff2 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available @@ -37,7 +37,7 @@ class HorovodPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: Optional[CheckpointIO] = None, ): super().__init__( parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index e29724a810b7e..96543ff88b6b9 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -23,7 +23,7 @@ from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import CombinedLoader @@ -68,7 +68,7 @@ def __init__( autoreport_dir: Optional[str] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: Optional[CheckpointIO] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 524c25d0c3124..691627105b62c 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -22,7 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -35,7 +35,7 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: 
Optional[CheckpointIO] = None, ): super().__init__(checkpoint_plugin) self.parallel_devices = parallel_devices diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 356760c6c97fc..359d922516c9c 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -15,7 +15,7 @@ import torch -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -26,7 +26,7 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__( self, device: torch.device, - checkpoint_plugin: Optional[CheckpointIOPlugin] = None, + checkpoint_plugin: Optional[CheckpointIO] = None, ): super().__init__(checkpoint_plugin) self.device: torch.device = device diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 9bec8f2390516..592dbe303a89f 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -15,7 +15,7 @@ from typing import Any, Dict from pytorch_lightning.core.decorators import parameter_validation -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -82,9 +82,9 @@ def teardown(self) -> None: os.environ.pop("PT_XLA_DEBUG", None) @property - def checkpoint_plugin(self) -> CheckpointIOPlugin: + def checkpoint_plugin(self) -> CheckpointIO: return self._checkpoint_plugin @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: + def checkpoint_plugin(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Plugin currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index c863f265d3b69..4152503cc46d4 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -25,7 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.decorators import parameter_validation from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import TrainerFn @@ -348,9 +348,9 @@ def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register("tpu_spawn_debug", cls, description="TPUSpawn Plugin with `debug` as True", debug=True) @property - def checkpoint_plugin(self) -> CheckpointIOPlugin: + def checkpoint_plugin(self) -> CheckpointIO: return self._checkpoint_plugin @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: + def 
checkpoint_plugin(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 915cd113c724a..fce3cbc226bcf 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -24,9 +24,9 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module -from pytorch_lightning.plugins import TorchCheckpointIOPlugin +from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.base_plugin import Plugin -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIOPlugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -37,19 +37,19 @@ class TrainingTypePlugin(Plugin, ABC): Base class for all training type plugins that change the behaviour of the training, validation and test-loop. """ - def __init__(self, checkpoint_plugin: Optional[CheckpointIOPlugin] = None) -> None: + def __init__(self, checkpoint_plugin: Optional[CheckpointIO] = None) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None - checkpoint_plugin = checkpoint_plugin if checkpoint_plugin is not None else TorchCheckpointIOPlugin() + checkpoint_plugin = checkpoint_plugin if checkpoint_plugin is not None else TorchCheckpointIO() self._checkpoint_plugin = checkpoint_plugin self._call_configure_sharded_model_hook = True @property - def checkpoint_plugin(self) -> CheckpointIOPlugin: + def checkpoint_plugin(self) -> CheckpointIO: return self._checkpoint_plugin @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIOPlugin) -> None: + def checkpoint_plugin(self, plugin: CheckpointIO) -> None: self._checkpoint_plugin = plugin def connect(self, model: Module) -> None: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 0a65e4bcb1a56..54ee79078cd98 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -26,7 +26,7 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.plugins import ( ApexMixedPrecisionPlugin, - CheckpointIOPlugin, + CheckpointIO, DataParallelPlugin, DDP2Plugin, DDPFullyShardedPlugin, @@ -135,7 +135,7 @@ def __init__( self._precision_plugin: Optional[PrecisionPlugin] = None self._training_type_plugin: Optional[TrainingTypePlugin] = None self._cluster_environment: Optional[ClusterEnvironment] = None - self._checkpoint_plugin: Optional[CheckpointIOPlugin] = None + self._checkpoint_plugin: Optional[CheckpointIO] = None plugins = plugins if plugins is not None else [] @@ -313,7 +313,7 @@ def handle_given_plugins(self) -> None: "You can only specify one precision plugin." 
f" Available: {type(precision).__name__}, given: {type(plug).__name__}" ) - elif isinstance(plug, CheckpointIOPlugin): + elif isinstance(plug, CheckpointIO): if checkpoint is None: checkpoint = plug else: diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index d195785a98ce1..35ed6765ffbff 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin -from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIOPlugin +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -202,7 +202,7 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu"), checkpoint_plugin=TorchCheckpointIOPlugin()) + plugin = TestPlugin(torch.device("cpu"), checkpoint_plugin=TorchCheckpointIO()) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index a605d07e61d20..e45d566d98650 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -19,14 +19,14 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.plugins import CheckpointIOPlugin, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin +from pytorch_lightning.plugins import CheckpointIO, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PATH from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf -class CustomCheckpointPlugin(CheckpointIOPlugin): +class CustomCheckpointIO(CheckpointIO): save_checkpoint_called: bool = False load_checkpoint_file_called: bool = False @@ -43,8 +43,8 @@ def test_checkpoint_plugin_called(tmpdir): """ Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading. 
""" - checkpoint_plugin = CustomCheckpointPlugin() - checkpoint_plugin = MagicMock(wraps=checkpoint_plugin, spec=CustomCheckpointPlugin) + checkpoint_plugin = CustomCheckpointIO() + checkpoint_plugin = MagicMock(wraps=checkpoint_plugin, spec=CustomCheckpointIO) ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) @@ -83,4 +83,4 @@ def test_checkpoint_plugin_called(tmpdir): @pytest.mark.parametrize("plugin_cls", [pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), TPUSpawnPlugin]) def test_no_checkpoint_io_plugin_support(plugin_cls): with pytest.raises(MisconfigurationException, match="currently does not support custom checkpoint plugins"): - plugin_cls().checkpoint_plugin = CustomCheckpointPlugin() + plugin_cls().checkpoint_plugin = CustomCheckpointIO() From 3b8c3f58df78e3359afa4d4a060277f8044db4ce Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 15:43:57 +0100 Subject: [PATCH 35/37] Clear up signature of function; checkpoint_plugin -> checkpoint_io --- docs/source/advanced/checkpoint_io.rst | 2 +- .../plugins/training_type/ddp.py | 4 ++-- .../plugins/training_type/ddp_spawn.py | 4 ++-- .../plugins/training_type/deepspeed.py | 8 ++++---- pytorch_lightning/plugins/training_type/dp.py | 6 ++---- .../plugins/training_type/fully_sharded.py | 4 ++-- .../plugins/training_type/horovod.py | 6 ++---- .../plugins/training_type/ipu.py | 4 ++-- .../plugins/training_type/parallel.py | 4 ++-- .../plugins/training_type/single_device.py | 4 ++-- .../plugins/training_type/single_tpu.py | 8 ++++---- .../plugins/training_type/tpu_spawn.py | 8 ++++---- .../training_type/training_type_plugin.py | 20 +++++++++---------- .../connectors/accelerator_connector.py | 8 ++++---- tests/accelerators/test_cpu.py | 2 +- tests/plugins/test_checkpoint_io_plugin.py | 4 ++-- 16 files changed, 46 insertions(+), 50 deletions(-) diff --git a/docs/source/advanced/checkpoint_io.rst b/docs/source/advanced/checkpoint_io.rst index 570692dbc5cc5..a2697ccb46a68 100644 --- a/docs/source/advanced/checkpoint_io.rst +++ b/docs/source/advanced/checkpoint_io.rst @@ -42,7 +42,7 @@ that is managed by the ``TrainingTypePlugin``. 
model = MyModel() device = torch.device("cpu") trainer = Trainer( - plugins=SingleDevicePlugin(device, checkpoint_plugin=checkpoint_plugin), + plugins=SingleDevicePlugin(device, checkpoint_io=checkpoint_plugin), callbacks=ModelCheckpoint(save_last=True), ) trainer.fit(model) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 133fa5c808502..8348a565fa486 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -77,7 +77,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, @@ -87,7 +87,7 @@ def __init__( super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, - checkpoint_plugin=checkpoint_plugin, + checkpoint_io=checkpoint_io, ) self.interactive_ddp_procs = [] if num_nodes is not None: diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index af10d9e7254d6..759743ad4a825 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -64,7 +64,7 @@ def __init__( parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, @@ -74,7 +74,7 @@ def __init__( super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, - checkpoint_plugin=checkpoint_plugin, + checkpoint_io=checkpoint_io, ) if num_nodes is not None: rank_zero_deprecation( diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 43e10022f4567..c0f3b10991aa8 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -825,9 +825,9 @@ def register_plugins(cls, plugin_registry: Dict) -> None: ) @property - def checkpoint_plugin(self) -> CheckpointIO: - return self._checkpoint_plugin + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io - @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIO) -> None: + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 6ed779a7fabc0..551324416cce9 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -33,11 +33,9 @@ class DataParallelPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]], - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, ): - super().__init__( - parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin - ) + 
super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 0fa008c42197c..29c74439dd5ee 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -41,7 +41,7 @@ def __init__( state_dict_to_cpu: bool = True, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, ): """ Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -93,7 +93,7 @@ def __init__( super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, - checkpoint_plugin=checkpoint_plugin, + checkpoint_io=checkpoint_io, ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 997b89008cff2..e5eb8bf9723ea 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -37,11 +37,9 @@ class HorovodPlugin(ParallelPlugin): def __init__( self, parallel_devices: Optional[List[torch.device]] = None, - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, ): - super().__init__( - parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin - ) + super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) rank_zero_only.rank = self.global_rank @property diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 96543ff88b6b9..4e711ddb406eb 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -68,7 +68,7 @@ def __init__( autoreport_dir: Optional[str] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -88,7 +88,7 @@ def __init__( super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, - checkpoint_plugin=checkpoint_plugin, + checkpoint_io=checkpoint_io, ) if not _POPTORCH_AVAILABLE or not poptorch.ipuHardwareIsAvailable(): raise MisconfigurationException( diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 691627105b62c..71aae1bb71a91 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -35,9 +35,9 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, ): - super().__init__(checkpoint_plugin) + super().__init__(checkpoint_io) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment diff --git 
a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 359d922516c9c..c92fead861c19 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -26,9 +26,9 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__( self, device: torch.device, - checkpoint_plugin: Optional[CheckpointIO] = None, + checkpoint_io: Optional[CheckpointIO] = None, ): - super().__init__(checkpoint_plugin) + super().__init__(checkpoint_io) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 592dbe303a89f..b6f7d4000da94 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -82,9 +82,9 @@ def teardown(self) -> None: os.environ.pop("PT_XLA_DEBUG", None) @property - def checkpoint_plugin(self) -> CheckpointIO: - return self._checkpoint_plugin + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io - @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIO) -> None: + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Plugin currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 4152503cc46d4..ee4cd9934d650 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -348,9 +348,9 @@ def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register("tpu_spawn_debug", cls, description="TPUSpawn Plugin with `debug` as True", debug=True) @property - def checkpoint_plugin(self) -> CheckpointIO: - return self._checkpoint_plugin + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io - @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIO) -> None: + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index fce3cbc226bcf..09363c3dc826b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -37,20 +37,20 @@ class TrainingTypePlugin(Plugin, ABC): Base class for all training type plugins that change the behaviour of the training, validation and test-loop. 
""" - def __init__(self, checkpoint_plugin: Optional[CheckpointIO] = None) -> None: + def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None - checkpoint_plugin = checkpoint_plugin if checkpoint_plugin is not None else TorchCheckpointIO() - self._checkpoint_plugin = checkpoint_plugin + checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() + self._checkpoint_io = checkpoint_io self._call_configure_sharded_model_hook = True @property - def checkpoint_plugin(self) -> CheckpointIO: - return self._checkpoint_plugin + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io - @checkpoint_plugin.setter - def checkpoint_plugin(self, plugin: CheckpointIO) -> None: - self._checkpoint_plugin = plugin + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: + self._checkpoint_io = plugin def connect(self, model: Module) -> None: """Called by the accelerator to connect the accelerator and the model with this plugin""" @@ -155,7 +155,7 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: return self._results def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - return self.checkpoint_plugin.load_checkpoint(checkpoint_path) + return self.checkpoint_io.load_checkpoint(checkpoint_path) def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: self.lightning_module.load_state_dict(checkpoint["state_dict"]) @@ -292,7 +292,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: # dump states as a checkpoint dictionary object checkpoint = self.on_save(checkpoint) if self.should_rank_save_checkpoint: - return self.checkpoint_plugin.save_checkpoint(checkpoint, filepath) + return self.checkpoint_io.save_checkpoint(checkpoint, filepath) @contextlib.contextmanager def model_sharded_context(self) -> Generator: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 54ee79078cd98..62f6f3619fd44 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -135,7 +135,7 @@ def __init__( self._precision_plugin: Optional[PrecisionPlugin] = None self._training_type_plugin: Optional[TrainingTypePlugin] = None self._cluster_environment: Optional[ClusterEnvironment] = None - self._checkpoint_plugin: Optional[CheckpointIO] = None + self._checkpoint_io: Optional[CheckpointIO] = None plugins = plugins if plugins is not None else [] @@ -335,7 +335,7 @@ def handle_given_plugins(self) -> None: self._training_type_plugin = training_type self._precision_plugin = precision - self._checkpoint_plugin = checkpoint + self._checkpoint_io = checkpoint self._cluster_environment = cluster_environment or self.select_cluster_environment() @property @@ -353,8 +353,8 @@ def training_type_plugin(self) -> TrainingTypePlugin: self._training_type_plugin = self.select_training_type_plugin() self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) # attach checkpoint plugin to the training type plugin - if self._checkpoint_plugin is not None: - self._training_type_plugin.checkpoint_plugin = self._checkpoint_plugin + if self._checkpoint_io is not None: + self._training_type_plugin.checkpoint_io = self._checkpoint_io self._training_type_plugin_resolved = True 
return self._training_type_plugin diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 35ed6765ffbff..a4a4d6d62d5d3 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -202,7 +202,7 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu"), checkpoint_plugin=TorchCheckpointIO()) + plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO()) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index e45d566d98650..ef43b8b14b146 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -52,7 +52,7 @@ def test_checkpoint_plugin_called(tmpdir): device = torch.device("cpu") trainer = Trainer( default_root_dir=tmpdir, - plugins=SingleDevicePlugin(device, checkpoint_plugin=checkpoint_plugin), + plugins=SingleDevicePlugin(device, checkpoint_io=checkpoint_plugin), callbacks=ck, max_epochs=1, ) @@ -83,4 +83,4 @@ def test_checkpoint_plugin_called(tmpdir): @pytest.mark.parametrize("plugin_cls", [pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), TPUSpawnPlugin]) def test_no_checkpoint_io_plugin_support(plugin_cls): with pytest.raises(MisconfigurationException, match="currently does not support custom checkpoint plugins"): - plugin_cls().checkpoint_plugin = CustomCheckpointIO() + plugin_cls().checkpoint_io = CustomCheckpointIO() From 9cfe98f968e9967ba4da9556313d642915ea22af Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 15:53:09 +0100 Subject: [PATCH 36/37] Slightly cleaner --- docs/source/advanced/checkpoint_io.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/advanced/checkpoint_io.rst b/docs/source/advanced/checkpoint_io.rst index a2697ccb46a68..d93b7adfd903b 100644 --- a/docs/source/advanced/checkpoint_io.rst +++ b/docs/source/advanced/checkpoint_io.rst @@ -6,7 +6,7 @@ Custom Checkpointing IO Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic that is managed by the ``TrainingTypePlugin``. -``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path, with the object being passed to either a `Trainer`` object or a``TrainingTypePlugin`` as shown below. +``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a `Trainer`` object or a``TrainingTypePlugin`` as shown below. .. code-block:: python From 8f234e0dd0c743c028148ab1eedf98d62403c1ef Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 13 Aug 2021 16:33:13 +0100 Subject: [PATCH 37/37] Address reviews --- docs/source/advanced/checkpoint_io.rst | 8 ++++---- pytorch_lightning/plugins/io/torch_plugin.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/advanced/checkpoint_io.rst b/docs/source/advanced/checkpoint_io.rst index d93b7adfd903b..6eabfae99b07b 100644 --- a/docs/source/advanced/checkpoint_io.rst +++ b/docs/source/advanced/checkpoint_io.rst @@ -18,7 +18,7 @@ that is managed by the ``TrainingTypePlugin``. 
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin - class CustomCheckpointPlugin(CheckpointIO): + class CustomCheckpointIO(CheckpointIO): def save_checkpoint( self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None ) -> None: @@ -28,12 +28,12 @@ that is managed by the ``TrainingTypePlugin``. ... - checkpoint_plugin = CustomCheckpointPlugin() + custom_checkpoint_io = CustomCheckpointIO() # Pass into the Trainer object model = MyModel() trainer = Trainer( - plugins=[checkpoint_plugin], + plugins=[custom_checkpoint_io], callbacks=ModelCheckpoint(save_last=True), ) trainer.fit(model) @@ -42,7 +42,7 @@ that is managed by the ``TrainingTypePlugin``. model = MyModel() device = torch.device("cpu") trainer = Trainer( - plugins=SingleDevicePlugin(device, checkpoint_io=checkpoint_plugin), + plugins=SingleDevicePlugin(device, checkpoint_io=custom_checkpoint_io), callbacks=ModelCheckpoint(save_last=True), ) trainer.fit(model) diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index 7d7200317ddf8..e95f3d3b226f7 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -43,7 +43,7 @@ def load_checkpoint( self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage ) -> Dict[str, Any]: """ - Loads checkpoint using torch.load, with additional handling for fsspec remote loading of files. + Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. Args: path: Path to checkpoint