Introduce CheckpointIO Plugin #8743
Merged
Changes shown are from 7 of the PR's 40 commits.

Commits:
a93e452 poc API
72f4dfd Merge branch 'master' into feat/ckpt_plugin
e7d2b66 Fix up the API, unsure on connection
b41e794 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9161980 Example API
7aa4e8c [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
dffe088 Update all constructors
cacf0e5 Move towards having the checkpoint plugin not require the plugin, and…
028ac38 Remove import
99c7a46 Fix tests
3adc486 Change name
b7d5b55 Cleanups
0a0a068 Fixes/Cleanups
97fb2a2 Use property
402156e Fixes to signature
5310a7f Merge branch 'master' into feat/ckpt_plugin
d7f567a Add warning for TPU plugins that they do not support custom checkpoin…
fcc24b4 Cleanup API, introduce storage options
4421276 Update signature to be more general
d84cce1 Address feedback, add test for support check
38c22a2 Merge branch 'master' into feat/ckpt_plugin
b7f37ee Add CHANGELOG.md
49086cc fix tests
936f65a change name
049a676 Fix mypy
1ff0912 Reviews
1841d3b Add ability to pass checkpoint plugin through the trainer
b909dfe Add constraints
50b11b5 Match signature to see if mypy works
9e16e34 Address review points
642e6fa Revert changes to typing
5c9e973 Add docs/doc strings and API
6361c87 Address feedback
c921dbb Update pytorch_lightning/plugins/training_type/training_type_plugin.py
fd82276 Address reviews
2fc3558 Update typing
21783f6 Refactor name
3b8c3f5 Clear up signature of function; checkpoint_plugin -> checkpoint_io
9cfe98f Slightly cleaner
8f234e0 Address reviews
New file (@@ -0,0 +1,60 @@): the abstract CheckpointPlugin base class.
from abc import ABC
from pathlib import Path
from typing import Any, Dict, Mapping, Union

from torch.nn import Module

from pytorch_lightning import LightningModule


class CheckpointPlugin(ABC):
    def __init__(self):
        self._training_type_plugin = None

    @property
    def training_type_plugin(self) -> "TrainingTypePlugin":
        return self._training_type_plugin

    @training_type_plugin.setter
    def training_type_plugin(self, plugin) -> None:
        self._training_type_plugin = plugin

    @property
    def lightning_module(self) -> LightningModule:
        return self.training_type_plugin.lightning_module

    @property
    def model(self) -> Module:
        return self.training_type_plugin.model

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
        """

    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        """Load a checkpoint from a path when resuming, or when loading a ckpt for the test/validate/predict stages.

        Args:
            checkpoint_path: Path to checkpoint

        Returns: The loaded checkpoint.
        """
        pass

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        """Given the loaded checkpoint file, loads the state dict into the model.

        Args:
            checkpoint: The loaded checkpoint file.
        """

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        """Given the loaded checkpoint file, loads the optimizer state dicts into the optimizers.

        Args:
            checkpoint: The loaded checkpoint file.
        """
New file (@@ -0,0 +1,123 @@): DeepSpeedCheckpointPlugin, the DeepSpeed implementation.
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Union

import torch
from torch import Tensor

from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
if _DEEPSPEED_AVAILABLE:
    import deepspeed


class DeepSpeedCheckpointPlugin(CheckpointPlugin):
    def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: The checkpoint state dictionary
            filepath: write-target file's path
        """
        if (
            self.training_type_plugin.zero_stage_3
            and self.training_type_plugin._multi_device
            and self.training_type_plugin.is_global_zero
        ):
            warning_cache.warn(
                "When saving the DeepSpeed Stage 3 checkpoint, "
                "each worker will save a shard of the checkpoint within a directory. "
                "If a single file is required after training, "
                "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
                "deepspeed-zero-stage-3-single-file for instructions."
            )
        # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
        # dump states as a checkpoint dictionary object
        _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
        checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
        self.model.save_checkpoint(filepath, client_state=checkpoint)

    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
        if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3:
            # Broadcast to ensure we load from the rank 0 checkpoint
            # This doesn't have to be the case when using deepspeed sharded checkpointing
            checkpoint_path = self.training_type_plugin.broadcast(checkpoint_path)
            return super().load_checkpoint_file(checkpoint_path)

        # Rely on deepspeed to load the checkpoint and necessary information
        from pytorch_lightning.trainer.states import TrainerFn

        is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
        _, client_state = self.model.load_checkpoint(
            checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
        )
        if client_state is None:
            raise MisconfigurationException(
                "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
                "or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`."
            )
        return client_state

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
        if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3:
            self.training_type_plugin.model_to_device()
            self._restore_zero_state(checkpoint)

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
        pass

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        """Returns model state."""
        model = self.lightning_module
        return model.state_dict()

    def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
        """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be
        sharded across processes before loading the state dictionary when using ZeRO stage 3.
        This is then automatically synced across processes.

        Args:
            ckpt: The ckpt file.
        """

        def load(module: torch.nn.Module, prefix=""):

            missing_keys = []
            unexpected_keys = []
            error_msgs = []
            state_dict = ckpt["state_dict"]

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            # because zero3 puts placeholders in model params, this context
            # manager gathers (unpartitions) the params of the current layer, then loads from
            # the state dict and then re-partitions them again
            with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                if self.training_type_plugin.is_global_zero:
                    module._load_from_state_dict(
                        state_dict=state_dict,
                        prefix=prefix,
                        local_metadata=local_metadata,
                        strict=True,
                        missing_keys=missing_keys,
                        unexpected_keys=unexpected_keys,
                        error_msgs=error_msgs,
                    )

            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(self.lightning_module, prefix="")
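A small standalone sketch (with made-up metadata values) of the client_state round trip the save/load pair above relies on. The DeepSpeed engine persists the module, optimizer and LR-scheduler shards itself, so the plugin strips those keys and only passes the remaining Trainer metadata through as client_state, which the engine's load_checkpoint later returns as the second element of its result.

# The keys the plugin removes before delegating to the engine (copied from the diff above).
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]

# Hypothetical checkpoint dict as assembled by the Trainer.
checkpoint = {
    "state_dict": {...},        # saved/sharded by the DeepSpeed engine itself
    "optimizer_states": [...],  # saved/sharded by the DeepSpeed engine itself
    "lr_schedulers": [...],     # saved by the DeepSpeed engine itself
    "epoch": 3,                 # example Trainer metadata that must survive the round trip
    "global_step": 1200,
}
client_state = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}

# engine.save_checkpoint("ckpt_dir", client_state=client_state)  # writes one shard per rank
# _, restored = engine.load_checkpoint("ckpt_dir")               # restored carries client_state back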
New file (@@ -0,0 +1,39 @@): TorchCheckpointPlugin, the plain-PyTorch implementation.
from pathlib import Path
from typing import Any, Dict, Mapping, Union

from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load


class TorchCheckpointPlugin(CheckpointPlugin):
    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # dump states as a checkpoint dictionary object
        try:
            # write the checkpoint dictionary on the file
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
            checkpoint.pop(key, None)
            rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
            atomic_save(checkpoint, filepath)

    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage))

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        self.lightning_module.load_state_dict(checkpoint["state_dict"])

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        """Returns model state."""
        model = self.lightning_module
        return model.state_dict()
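The default implementation above is also a natural extension point. A hedged sketch, assuming the TorchCheckpointPlugin and atomic_save definitions above are in scope: a subclass that mirrors every checkpoint to a second location, for example a remote URI, since Lightning's cloud_io helpers are fsspec-backed. The MirroredTorchCheckpointPlugin name and the mirror_dir argument are illustrative, not part of this PR.

import os
from typing import Any, Dict


class MirroredTorchCheckpointPlugin(TorchCheckpointPlugin):
    """Hypothetical subclass that writes an additional copy of every checkpoint."""

    def __init__(self, mirror_dir: str):
        super().__init__()
        self.mirror_dir = mirror_dir

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        # write the primary checkpoint exactly as the parent class would
        super().save_checkpoint(checkpoint, filepath)
        # then mirror it; atomic_save goes through fsspec, so mirror_dir may be a remote path
        mirror_path = os.path.join(self.mirror_dir, os.path.basename(str(filepath)))
        atomic_save(checkpoint, mirror_path)

Such a subclass changes only where checkpoints end up while leaving loading behaviour untouched, which is precisely the separation of concerns this PR introduces.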