@@ -22,8 +22,12 @@ In this tutorial, we show how to use DCP APIs with a simple FSDP wrapped model.
 How DCP works
 --------------

-:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel.
-In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies.
+:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel,
+and then re-shard across differing cluster topologies at load time.
+
+Additionally, through the use of modules in :func:`torch.distributed.checkpoint.state_dict`,
+DCP offers support for gracefully handling ``state_dict`` generation and loading in distributed settings.
+This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch provided parallelisms.

 DCP is different from :func:`torch.save` and :func:`torch.load` in a few significant ways:
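As a quick illustration of the ``state_dict`` utilities the new text refers to, here is a minimal sketch; it assumes ``model`` and ``optimizer`` are an already-constructed parallelized module and its optimizer, and is not part of the diff itself:

.. code-block:: python

    from torch.distributed.checkpoint.state_dict import get_state_dict

    # Returns checkpoint-friendly dictionaries keyed by fully qualified names (FQNs),
    # handling parallelism-specific details (such as FSDP-sharded parameters) for you.
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)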
@@ -42,19 +46,20 @@ Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly,
 Saving
 ~~~~~~

-Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
+Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.

 .. code-block:: python

     import os

     import torch
     import torch.distributed as dist
-    import torch.distributed.checkpoint as DCP
+    import torch.distributed.checkpoint as dcp
     import torch.multiprocessing as mp
     import torch.nn as nn

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.checkpoint.state_dict import get_state_dict
     from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

     CHECKPOINT_DIR = "checkpoint"
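Using the imports above, the save path can look roughly like the following sketch; it assumes ``model`` is the FSDP-wrapped toy module and ``optimizer`` its optimizer, created on each rank as in the full recipe, and the storage writer shown is only one possible choice:

.. code-block:: python

    # Sketch under the tutorial's assumptions: a process group is initialized
    # and ``model`` / ``optimizer`` already exist on this rank.
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    state_dict = {
        "model": model_state_dict,
        "optimizer": optimizer_state_dict,
    }

    # each rank writes its shard of the checkpoint in parallel
    dcp.save(state_dict, storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR))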
@@ -99,20 +104,14 @@ Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy inp
+        # sets our state dicts on the model and optimizer, now that we've loaded
+        set_state_dict(
+            model,
+            optimizer,
+            model_state_dict=model_state_dict,
+            optim_state_dict=optimizer_state_dict
         )
-        model.load_state_dict(state_dict["model"])

         cleanup()
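To put the ``set_state_dict`` call above in context, here is a hedged sketch of the surrounding load flow; variable names (``model``, ``optimizer``, ``CHECKPOINT_DIR``) mirror the tutorial's toy example, and the storage reader shown is only one possible choice:

.. code-block:: python

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    # obtain parallelism-aware state dicts to serve as load targets
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    state_dict = {
        "model": model_state_dict,
        "optimizer": optimizer_state_dict,
    }

    # populate the state dicts in place from the checkpoint directory
    dcp.load(state_dict, storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR))

    # apply the loaded values back onto the wrapped model and optimizer
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optimizer"],
    )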
@@ -224,7 +225,8 @@ The reason that we need the ``state_dict`` prior to loading is:
     )

 If you would like to load the saved checkpoint into a non-FSDP wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP.
-By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style. To load without a distributed setup, please set ``no_dist`` to ``True`` when loading with DCP.
+By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers
+the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

 .. note::
   Distributed checkpoint support for Multi-Program Multi-Data is still under development.
@@ -259,11 +261,10 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M
         "model": model.state_dict(),
     }

-    # turn no_dist to be true to load in non-distributed setting
-    DCP.load_state_dict(
+    # since no process group is initialized, DCP will disable any collectives.
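For completeness, here is a sketch of what the non-distributed load can look like end to end; ``ToyModel`` and ``CHECKPOINT_DIR`` stand in for the tutorial's toy module and checkpoint directory, and no process group is initialized, so DCP runs entirely in the current process:

.. code-block:: python

    import torch.distributed.checkpoint as dcp

    model = ToyModel()  # plain, non-FSDP module whose FQNs match the checkpoint
    state_dict = {
        "model": model.state_dict(),
    }

    # no process group is initialized, so DCP loads in "non-distributed" style
    dcp.load(state_dict, storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR))
    model.load_state_dict(state_dict["model"])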
@@ -274,7 +275,9 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M

 Conclusion
 ----------
-In conclusion, we have learned how to use DCP's :func:`save_state_dict` and :func:`load_state_dict` APIs, as well as how they are different form :func:`torch.save` and :func:`torch.load`.
+In conclusion, we have learned how to use DCP's :func:`save` and :func:`load` APIs, as well as how they are different from :func:`torch.save` and :func:`torch.load`.
+Additionally, we've learned how to use :func:`get_state_dict` and :func:`set_state_dict` to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.