Follow up — changes to load model state in checkpoint connector in case of multiple workers #8044 #8515
```diff
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import logging
+from pathlib import Path
 from typing import Any, Dict, Generator, List, Optional, Union

 import torch
```
```diff
@@ -27,6 +29,8 @@
 from fairscale.nn import default_auto_wrap_policy, enable_wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel

+log: logging.Logger = logging.getLogger(__name__)
+

 class DDPFullyShardedPlugin(DDPPlugin):
```
```diff
@@ -178,6 +182,19 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
         # state dict.
         return super().lightning_module_state_dict()

+    def serialized_restore_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        checkpoint = {}
+        log.info(f"FullyShardedDataParallel has {self.num_processes} processes. Serializing to avoid CPU OOMs.")
+
+        for current_worker in range(self.num_processes):
```
> Review comment: Would this break if the number of GPUs at checkpointing/saving time isn't the same? Should we save the world_size + rank in each checkpoint and reuse that information on reload?
>
> Reply: Thanks for raising this! Thinking aloud: in the same training session we wouldn't run into this problem, but even in the case where a model checkpoint from a different training session is being fine-tuned in a new training session on a different machine, do we envision that the world size would be different? (For our use cases at least, we have config files that keep variables like Trainer(gpus) constant for the same model training.)
>
> Reply: I think that is definitely something we need to support. For example, I usually use the output from …
>
> Reply: @tchaton this kind of training "metadata" should get saved with the checkpoint. For example, we will also want to know this for fault tolerance, to fail if the trainer configuration has changed between runs and the user is trying to restore mid-batch.
>
> Reply: FWIW, I ran a test taking a checkpoint from a previous training session and starting a fresh one: setting trainer.resume_from_checkpoint to the old checkpoint and trainer.gpus=4 instead of 8 (which was the number of GPUs in the original training session), and loading the checkpoint on 4 GPUs didn't break. I can add this feature in another pull request! Would a good place to add the metadata be …
>
> Reply: Sorry for the late reply! I've filed #9123 to track this.
>
> Reply: Mind adding a comment to explain what is happening, in case a new reader reaches this code :)
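Since the thread above converged on storing training metadata with the checkpoint, here is a minimal sketch of what that could look like. The key name (`world_size`), the helper names, and the idea of calling them from an `on_save_checkpoint`-style hook are assumptions made for illustration; they are not part of this PR or of the Lightning API.

```python
# Hypothetical sketch: persist the world size alongside the checkpoint and
# verify it before restoring. Key and helper names are illustrative only.
from typing import Any, Dict


def add_world_size_metadata(checkpoint: Dict[str, Any], world_size: int) -> None:
    # Would run while the checkpoint dict is being assembled, before it is written.
    checkpoint["world_size"] = world_size


def check_world_size_metadata(checkpoint: Dict[str, Any], current_world_size: int) -> None:
    # Would run right after the checkpoint file is loaded, before any state is restored.
    saved = checkpoint.get("world_size")
    if saved is not None and saved != current_world_size:
        raise RuntimeError(
            f"Checkpoint was written with world_size={saved}, "
            f"but this run uses world_size={current_world_size}."
        )
```

In the resume test described above (a checkpoint written on 8 GPUs, resumed with trainer.gpus=4), a check like this would turn the implicit assumption into an explicit error, or into a warning if resharding turns out to be supported.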
```diff
+            if self.local_rank == current_worker:
+                checkpoint = super().load_checkpoint_file(checkpoint_path)
```
> Review comment: Does this assume the same checkpoint path for all workers?
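If that assumption ever needs to be enforced rather than assumed (for example with rank-local scratch directories), the path could be broadcast from rank 0 before the loop. Below is a minimal sketch using plain `torch.distributed`, assuming the process group is already initialized; the helper name is made up for illustration.

```python
# Sketch: make every rank use the checkpoint path chosen by rank 0.
# Assumes torch.distributed is already initialized (PyTorch >= 1.8 for
# broadcast_object_list). Helper name is illustrative only.
import torch.distributed as dist


def agree_on_checkpoint_path(checkpoint_path: str) -> str:
    paths = [checkpoint_path]
    # On non-source ranks the list contents are overwritten with rank 0's value.
    dist.broadcast_object_list(paths, src=0)
    return paths[0]
```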
```diff
+                self.lightning_module.on_load_checkpoint(checkpoint)
+                self.load_model_state_dict(checkpoint)
+                log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
+                del checkpoint["state_dict"]
+            self.barrier()
```
> Review comment: This is pretty smart :) One concern: the responsibility of calling the hooks is now shifted to the plugin. Do we want to allow that?
>
> Reply: In this case, I don't see a problem with it. The only thing I am afraid of is that we may set a precedent here, and that might lead to this pattern in places where we don't want it. What do you think @SeanNaren @awaelchli @ananthsub @carmocca?
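For context on that concern, the usual restore sequence keeps both calls on the trainer side; the sketch below paraphrases the behaviour described in this thread rather than copying the checkpoint connector verbatim. With `serialized_restore_model_state`, both calls move into the plugin's per-rank loop.

```python
# Simplified, paraphrased sketch of the usual restore sequence (not the actual
# CheckpointConnector source): the trainer side runs the LightningModule hook
# and then asks the training-type plugin to restore the weights.
def restore_model_state(trainer, checkpoint: dict) -> None:
    # LightningModule.on_load_checkpoint hook
    trainer.lightning_module.on_load_checkpoint(checkpoint)
    # the plugin loads the (possibly sharded) weights into the wrapped model
    trainer.training_type_plugin.load_model_state_dict(checkpoint)
```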
```diff
+
+        return checkpoint
+
```
> Review comment: I guess returning the checkpoint is not required. Plus: docstrings.
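Picking up the requests above for an explanatory comment and a docstring, here is one possible annotated version of the method, shown as it would sit inside `DDPFullyShardedPlugin` (so `log`, `self`, and `super()` refer to the surrounding module and class). The logic is unchanged from the diff; the docstring and comments are only suggestions.

```python
def serialized_restore_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
    """Load the checkpoint file one rank at a time to avoid CPU OOM.

    Each process loads the full checkpoint, restores its model state, and then
    drops the large ``state_dict`` entry before the next rank starts, so at most
    one full copy of the state dict sits in CPU memory per node at any time.
    """
    checkpoint: Dict[str, Any] = {}
    log.info(f"FullyShardedDataParallel has {self.num_processes} processes. Serializing to avoid CPU OOMs.")

    for current_worker in range(self.num_processes):
        # Only the rank whose turn it is touches the checkpoint file.
        if self.local_rank == current_worker:
            checkpoint = super().load_checkpoint_file(checkpoint_path)
            self.lightning_module.on_load_checkpoint(checkpoint)
            self.load_model_state_dict(checkpoint)
            log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
            # Free the model weights; the rest of the checkpoint (optimizer
            # states, loop state, ...) is still needed by the caller.
            del checkpoint["state_dict"]
        # Everyone waits here, so the next rank only starts once this one is done.
        self.barrier()

    return checkpoint
```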
```diff
     @property
     def setup_optimizers_in_pre_dispatch(self) -> bool:
         # Setup optimizers after the Fully Sharded Model has been made
```
> Review comment: Docstring?
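A possible docstring for the property, in response to that comment. This assumes the property simply returns True, as the existing inline comment implies; the wording is a suggestion only.

```python
@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
    """Defer optimizer setup until the model is wrapped in FullyShardedDataParallel.

    The optimizers have to be created against the flattened, sharded parameters,
    so they are set up in ``pre_dispatch`` rather than during the regular setup phase.
    """
    return True
```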