
Commit c5c0a9a

LucasLLC and svekars authored
Updates dcp tutorial with recent updates to api including save, load, and distributed state dict (#2832)
* updates dcp tutorial with recent updates to api including save, load

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 700b9d8 commit c5c0a9a

File tree: 1 file changed (+38 −35 lines)

recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 38 additions & 35 deletions
@@ -1,7 +1,7 @@
 Getting Started with Distributed Checkpoint (DCP)
 =====================================================

-**Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__
+**Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__, `Lucas Pasqualin <https://github.com/lucasllc>`__

 .. note::
    |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__.
@@ -22,8 +22,12 @@ In this tutorial, we show how to use DCP APIs with a simple FSDP wrapped model.
 How DCP works
 --------------

-:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel.
-In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies.
+:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel,
+and then re-shard across differing cluster topologies at load time.
+
+Additionally, through the use of modules in :func:`torch.distributed.checkpoint.state_dict`,
+DCP offers support for gracefully handling ``state_dict`` generation and loading in distributed settings.
+This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.

 DCP is different from :func:`torch.save` and :func:`torch.load` in a few significant ways:

@@ -42,19 +46,20 @@ Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly,
 Saving
 ~~~~~~

-Now, lets create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
+Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.

 .. code-block:: python

     import os

     import torch
     import torch.distributed as dist
-    import torch.distributed.checkpoint as DCP
+    import torch.distributed.checkpoint as dcp
     import torch.multiprocessing as mp
     import torch.nn as nn

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.checkpoint.state_dict import get_state_dict
     from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

     CHECKPOINT_DIR = "checkpoint"
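
Putting the "How DCP works" description and the new imports together, the flow this commit moves the tutorial to can be condensed into a single-process sketch: ``get_state_dict`` produces FQN-normalized model and optimizer state dicts, ``dcp.save`` writes them, and ``dcp.load`` followed by ``set_state_dict`` restores them. The tiny ``nn.Linear`` model, the one-rank ``gloo`` process group, and the ``checkpoint_sketch`` directory below are illustrative assumptions, not part of the commit or the tutorial.

.. code-block:: python

    # Minimal single-process sketch (illustrative only):
    # get_state_dict -> dcp.save -> dcp.load -> set_state_dict.
    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    import torch.nn as nn
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    CHECKPOINT_DIR = "checkpoint_sketch"  # illustrative directory name

    # a one-rank gloo process group keeps the sketch runnable on a single CPU
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = nn.Linear(16, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    model(torch.rand(4, 16)).sum().backward()
    optimizer.step()

    # FQN-normalized state dicts for both the model and the optimizer
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    dcp.save({"model": model_state_dict, "optimizer": optimizer_state_dict},
             checkpoint_id=CHECKPOINT_DIR)

    # restore into fresh objects: build matching state dicts, load, then set them back
    new_model = nn.Linear(16, 8)
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.1)
    model_state_dict, optimizer_state_dict = get_state_dict(new_model, new_optimizer)
    dcp.load({"model": model_state_dict, "optimizer": optimizer_state_dict},
             checkpoint_id=CHECKPOINT_DIR)
    set_state_dict(new_model, new_optimizer,
                   model_state_dict=model_state_dict,
                   optim_state_dict=optimizer_state_dict)

    dist.destroy_process_group()
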
@@ -99,20 +104,14 @@ Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy inp
         model(torch.rand(8, 16, device="cuda")).sum().backward()
         optimizer.step()

-        # set FSDP StateDictType to SHARDED_STATE_DICT so we can use DCP to checkpoint sharded model state dict
-        # note that we do not support FSDP StateDictType.LOCAL_STATE_DICT
-        FSDP.set_state_dict_type(
-            model,
-            StateDictType.SHARDED_STATE_DICT,
-        )
+        # this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
         state_dict = {
-            "model": model.state_dict(),
+            "model": model_state_dict,
+            "optimizer": optimizer_state_dict
         }
+        dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

-        DCP.save_state_dict(
-            state_dict=state_dict,
-            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
-        )

         cleanup()

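The removed ``DCP.save_state_dict`` call configured an explicit ``storage_writer``; ``dcp.save`` still accepts one when you want to configure storage yourself instead of passing a ``checkpoint_id``. A sketch, reusing ``state_dict`` and ``CHECKPOINT_DIR`` from the saving example above:

.. code-block:: python

    # equivalent save with an explicit writer (sketch; reuses state_dict and
    # CHECKPOINT_DIR from the saving example above)
    dcp.save(
        state_dict,
        storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
    )
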
@@ -152,12 +151,12 @@ The reason that we need the ``state_dict`` prior to loading is:

     import torch
     import torch.distributed as dist
-    import torch.distributed.checkpoint as DCP
+    import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

     CHECKPOINT_DIR = "checkpoint"

@@ -194,21 +193,23 @@ The reason that we need the ``state_dict`` prior to loading is:
         model = ToyModel().to(rank)
         model = FSDP(model)

-        FSDP.set_state_dict_type(
-            model,
-            StateDictType.SHARDED_STATE_DICT,
-        )
-        # different from ``torch.load()``, DCP requires model state_dict prior to loading to get
-        # the allocated storage and sharding information.
+        # generates the state dict we will load into
+        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
         state_dict = {
-            "model": model.state_dict(),
+            "model": model_state_dict,
+            "optimizer": optimizer_state_dict
         }
-
-        DCP.load_state_dict(
+        dcp.load(
             state_dict=state_dict,
-            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+            checkpoint_id=CHECKPOINT_DIR,
+        )
+        # sets our state dicts on the model and optimizer, now that we've loaded
+        set_state_dict(
+            model,
+            optimizer,
+            model_state_dict=model_state_dict,
+            optim_state_dict=optimizer_state_dict
         )
-        model.load_state_dict(state_dict["model"])

         cleanup()

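Similarly, ``dcp.load`` should accept an explicit ``storage_reader`` (mirroring the removed ``DCP.load_state_dict`` call) in place of ``checkpoint_id``; the ``set_state_dict`` step afterwards is unchanged. A sketch, reusing names from the loading example above:

.. code-block:: python

    # equivalent load with an explicit reader (sketch; state_dict is the dict
    # built from get_state_dict in the loading example above)
    dcp.load(
        state_dict,
        storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
    )
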
@@ -224,7 +225,8 @@ The reason that we need the ``state_dict`` prior to loading is:
     )

 If you would like to load the saved checkpoint into a non-FSDP wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP.
-By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data(SPMD) style. To load without a distributed setup, please set ``no_dist`` to ``True`` when loading with DCP.
+By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers
+that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

 .. note::
    Distributed checkpoint support for Multi-Program Multi-Data is still under development.
@@ -259,11 +261,10 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M
         "model": model.state_dict(),
     }

-    # turn no_dist to be true to load in non-distributed setting
-    DCP.load_state_dict(
+    # since no process group is initialized, DCP will disable any collectives.
+    dcp.load(
         state_dict=state_dict,
-        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
-        no_dist=True,
+        checkpoint_id=CHECKPOINT_DIR,
     )
     model.load_state_dict(state_dict["model"])

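For this model-only inference path, ``torch.distributed.checkpoint.state_dict`` also exposes model-only helpers. The sketch below assumes ``get_model_state_dict`` and ``set_model_state_dict`` are available in your PyTorch build and reuses the tutorial's ``ToyModel`` and ``CHECKPOINT_DIR``:

.. code-block:: python

    # model-only variant of the non-distributed load above (a sketch)
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        set_model_state_dict,
    )

    model = ToyModel()  # ToyModel as defined earlier in the tutorial
    model_state_dict = get_model_state_dict(model)
    dcp.load({"model": model_state_dict}, checkpoint_id=CHECKPOINT_DIR)
    set_model_state_dict(model, model_state_dict)
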
@@ -274,7 +275,9 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M

 Conclusion
 ----------
-In conclusion, we have learned how to use DCP's :func:`save_state_dict` and :func:`load_state_dict` APIs, as well as how they are different form :func:`torch.save` and :func:`torch.load`.
+In conclusion, we have learned how to use DCP's :func:`save` and :func:`load` APIs, as well as how they are different from :func:`torch.save` and :func:`torch.load`.
+Additionally, we've learned how to use :func:`get_state_dict` and :func:`set_state_dict` to automatically manage parallelism-specific FQNs and defaults during state dict
+generation and loading.

 For more information, please see the following:
