Avoid CPU OOM when loading full-state FSDP checkpoints in Fabric #18138
What does this PR do?
Fixes #18008
Related: #8043
Resolves a comment in the code regarding the loading of full-state checkpoints into an FSDP model: before this PR, each worker loads its own copy of the checkpoint into CPU memory. For example, on a machine with 8 GPUs and a 10 GB checkpoint, loading the state dict occupies 8 * 10 = 80 GB of CPU memory. If there is not enough CPU RAM, we get an OOM.
This PR implements a strategy in which the state dict is loaded sequentially across the local ranks, as sketched below. The drawback is that loading takes N times longer, where N is the number of GPUs on the machine.
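The following is a minimal sketch of the rank-sequential idea, not the actual Fabric implementation: the helper name, its arguments (`module`, `path`, `local_rank`, `local_world_size`), and the plain `load_state_dict` call are assumptions for illustration only.

```python
import torch
import torch.distributed as dist


def load_full_checkpoint_sequentially(module, path, local_rank, local_world_size):
    """Hypothetical helper: load a full-state checkpoint one local rank at a time.

    At most one process per node holds the full checkpoint in CPU memory at any
    moment, so peak host memory stays near the checkpoint size instead of
    growing with the number of GPUs.
    """
    for rank in range(local_world_size):
        if rank == local_rank:
            # Only the current rank materializes the checkpoint on CPU.
            state_dict = torch.load(path, map_location="cpu")
            # With a real FSDP-wrapped module this would happen under the
            # appropriate FSDP state-dict-type context; a plain call is used
            # here for brevity.
            module.load_state_dict(state_dict)
            # Release the host-memory copy before the next rank starts loading.
            del state_dict
        # Every process waits until the current rank has finished.
        dist.barrier()
```

The trade-off mentioned above is visible in the sketch: the barrier serializes the `torch.load` calls, so total loading time grows with the number of local ranks while peak CPU memory stays bounded.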
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:
Reviewer checklist