Follow up — changes to load model state in checkpoint connector in case of multiple workers #8044 #8515
Diff 1 of 2: fully sharded training type plugin (DDPFullyShardedPlugin)
@@ -12,20 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from typing import Dict, Generator, List, Optional
+import logging
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Optional, Union

 import torch
 from torch import Tensor

 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
-from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
     from fairscale.nn import default_auto_wrap_policy, enable_wrap
     from fairscale.nn.data_parallel import FullyShardedDataParallel

+log: logging.Logger = logging.getLogger(__name__)
+
+
 class DDPFullyShardedPlugin(DDPPlugin):
     def __init__(
@@ -174,6 +179,42 @@ def model_to_device(self) -> None:
         # ensure we update the device type in the lightning module
         self.lightning_module.to(self.root_device)

+    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
+        # Currently it is the same as the default TrainingTypePlugin, i.e. return
+        # the full state dict for FSDP; in the future, we will provide a sharded
+        # state dict.
+        return super().lightning_module_state_dict()
+
+    def load_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        """
+        Used by each rank to reload the model weights.
+
+        Args:
+            checkpoint_path: Path to the current checkpoint.
+
+        Returns:
+            checkpoint: The loaded checkpoint, with "state_dict" removed.
+        """
+        checkpoint = {}
+        rank_zero_info(
+            f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model "
+            "state restoration to avoid CPU OOMs."
+        )
+        # Each rank loads the current checkpoint from `checkpoint_path` and restores the
+        # weights while the other ranks wait behind a barrier, so that only one process
+        # holds the full checkpoint in CPU memory at a time.
+        for current_worker in range(self.num_processes):
+            if self.local_rank == current_worker:
+                checkpoint = self.load_checkpoint_file(checkpoint_path)
+                self.on_load_checkpoint(checkpoint)
+                self.load_model_state_dict(checkpoint.pop("state_dict"))
+                log.info(
+                    f"Rank {self.global_rank}: done loading model states from {checkpoint_path}, "
+                    "deleted state_dict from checkpoint."
+                )
+            self.barrier()
+        return checkpoint
+
     @property
     def setup_optimizers_in_pre_dispatch(self) -> bool:
         # Setup optimizers after the Fully Sharded Model has been made

Review comments on this hunk:

- On the load_model_state docstring: docstring ?
- On the line "for current_worker in range(self.num_processes):":
  - Would this break if the number of GPUs at checkpointing / saving time isn't the same? Should we save the world_size + rank in each checkpoint and re-use that information on reload? (See the sketch below these review notes.)
  - Thanks for raising this! Thinking aloud: in the same training session we wouldn't run into this problem, but even when a model checkpoint from a different training session is being fine-tuned in a new session on a different machine, do we envision that the world size would be different? (For our use cases at least, we have config files that keep variables like Trainer(gpus) constant for the same model training.)
  - I think that is definitely something we need to support. For example, I usually use the output from …
  - @tchaton this kind of training "metadata" should get saved with the checkpoint. For example, we will also want to know this for fault tolerance, to fail if the trainer configuration has changed between runs and the user is trying to restore mid-batch.
  - FWIW, I ran a test taking a checkpoint from a previous training session and starting a fresh one: setting trainer.resume_from_checkpoint to the old checkpoint and trainer.gpus=4 instead of 8 (which was the number of GPUs in the original training session), and loading the checkpoint on 4 GPUs didn't break. I can add this feature in another pull request! Would a good place to add the metadata be …
  - Sorry for the late reply! I've filed #9123 to track this.
  - Mind adding a comment to explain what is happening, in case a new reader reaches this code? :)
- On "return checkpoint": I guess returning the checkpoint is not required. Plus docstrings.
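The world_size / topology question in the thread above is tracked in #9123. As a rough, hypothetical sketch only (not code from this pull request), one way to record the distributed topology in the checkpoint and to check it on reload could look like the following; the "distributed_metadata" key and both helper names are invented for this illustration:

from typing import Any, Dict
import warnings


def add_topology_metadata(checkpoint: Dict[str, Any], world_size: int, num_nodes: int) -> Dict[str, Any]:
    # Record the topology the checkpoint was saved with, next to the usual contents.
    checkpoint["distributed_metadata"] = {"world_size": world_size, "num_nodes": num_nodes}
    return checkpoint


def check_topology_metadata(checkpoint: Dict[str, Any], world_size: int, num_nodes: int) -> None:
    # Warn rather than fail when the restore topology differs from the save topology,
    # since the 4-vs-8 GPU experiment described in the thread suggests reloading can still work.
    meta = checkpoint.get("distributed_metadata")
    if meta is not None and (meta["world_size"] != world_size or meta["num_nodes"] != num_nodes):
        warnings.warn(
            f"Checkpoint was saved with world_size={meta['world_size']} and num_nodes={meta['num_nodes']} "
            f"but is being restored with world_size={world_size} and num_nodes={num_nodes}."
        )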
Diff 2 of 2: base training type plugin (TrainingTypePlugin)
@@ -160,6 +160,15 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         self.lightning_module.load_state_dict(checkpoint["state_dict"])

+    def on_load_checkpoint(self, checkpoint: Mapping[str, Any]) -> None:
+        self.lightning_module.on_load_checkpoint(checkpoint)
+
+    def load_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        checkpoint = self.load_checkpoint_file(checkpoint_path)
+        self.on_load_checkpoint(checkpoint)
+        self.load_model_state_dict(checkpoint)
+        return checkpoint
+
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         optimizer_states = checkpoint["optimizer_states"]
         for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states):

Review comment on "return checkpoint": same here (the same notes as above about the return value and the missing docstrings).
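For orientation, a caller such as the checkpoint connector could use the new hook roughly as follows. This is a simplified, hypothetical sketch rather than the actual CheckpointConnector code, and it relies only on the plugin methods shown in the two diffs above:

def restore_model_then_training_state(plugin, checkpoint_path):
    # Restore the model weights first; in the fully sharded plugin this call serializes
    # the checkpoint read across ranks behind barriers to avoid CPU OOMs.
    checkpoint = plugin.load_model_state(checkpoint_path)
    # The remaining training state (e.g. optimizer states) can then be restored from the
    # same in-memory checkpoint dict without reading the file again.
    plugin.load_optimizer_state_dict(checkpoint)
    return checkpoint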
Review comment: Let's add num_nodes too.
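If this last comment refers to the rank_zero_info message in the fully sharded plugin (an assumption, since the comment's anchor line is not shown here), the reported topology could be extended along these lines; the helper name and its parameters are illustrative only:

from pytorch_lightning.utilities import rank_zero_info


def log_restore_plan(num_processes: int, num_nodes: int) -> None:
    # Report both the per-node process count and the node count before serialized restoration.
    rank_zero_info(
        f"FullyShardedDataParallel is running {num_processes} processes per node on {num_nodes} node(s). "
        "Serializing model state restoration to avoid CPU OOMs."
    )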