Commit 4c27d25

Fsdp code fixes (#1867)
* fix the code snippets
* fix the blog link
* fixes the bullet points
1 parent 62ff2fd commit 4c27d25

File tree

1 file changed: intermediate_source/FSDP_tutorial.rst (+28 -15 lines)

@@ -9,7 +9,7 @@ It also comes with considerable engineering complexity to handle the training of
 `Pytorch FSDP <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__, released in PyTorch 1.11 makes this easier.
 
 In this tutorial, we show how to use `FSDP APIs <https://pytorch.org/docs/1.11/fsdp.html>`__, for simple MNIST models that can be extended to other larger models such as `HuggingFace BERT models <https://huggingface.co/blog/zero-deepspeed-fairscale>`__,
-`GPT 3 models up to 1T parameters <https://pytorch.medium.com/pytorch-data-parallel-best-practices-on-google-cloud-6c8da2be180d>`__ . The sample DDP MNIST code has been borrowed from `here <https://github.com/yqhu/mnist_examples>`__.
+`GPT 3 models up to 1T parameters <https://pytorch.medium.com/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__ . The sample DDP MNIST code has been borrowed from `here <https://github.com/yqhu/mnist_examples>`__.
 
 
 How FSDP works
@@ -28,18 +28,21 @@ FSDP GPU memory footprint would be smaller than DDP across all workers. This mak
 At a high level FSDP works as follows:
 
 *In constructor*
-Shard model parameters and each rank only keeps its own shard
+
+* Shard model parameters and each rank only keeps its own shard
 
 *In forward path*
-Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
-Run forward computation
-Discard parameter shards it has just collected
+
+* Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
+* Run forward computation
+* Discard parameter shards it has just collected
 
 *In backward path*
-Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
-Run backward computation
-Run reduce_scatter to sync gradients
-Discard parameters.
+
+* Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
+* Run backward computation
+* Run reduce_scatter to sync gradients
+* Discard parameters.
 
 How to use FSDP
 --------------
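
All of the shard/allgather/reduce_scatter steps listed above happen inside the FSDP wrapper; user code only wraps the module. A minimal sketch of that mapping, reusing the tutorial's ``Net``, ``rank``, ``data``, and ``target`` names and assuming the process group is already initialized and ``F`` is ``torch.nn.functional``:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # constructor: parameters are sharded; each rank keeps only its own shard
    model = FSDP(Net().to(rank))

    # forward: allgather the full parameters of each FSDP unit, compute, discard shards
    output = model(data)
    loss = F.nll_loss(output, target)

    # backward: allgather again, compute gradients, reduce_scatter them, discard parameters
    loss.backward()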
@@ -49,7 +52,8 @@ Here we use a toy model to run training on MNIST dataset for demonstration purpo
 
 1.1 Install Pytorch along with Torchvision
 
-.. code-block::
+.. code-block:: bash
+
     pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
 
 We add the following code snippets to a python script “FSDP_mnist.py”.
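
Before any of the model snippets below, “FSDP_mnist.py” needs the usual torch.distributed process-group setup. A hedged sketch of that boilerplate (the helper names ``setup``/``cleanup`` and the master address/port values are illustrative, not necessarily the tutorial's exact code):

.. code-block:: python

    import os
    import torch.distributed as dist

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # initialize the default process group; NCCL is the usual backend for GPU training
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()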
@@ -133,6 +137,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 2.2 define a train function
 
 .. code-block:: python
+
     def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
         model.train()
         ddp_loss = torch.zeros(2).to(rank)
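
The two-element ``ddp_loss`` tensor is an accumulator: one slot for the summed loss and one for the number of samples, so a single ``all_reduce`` at the end of the epoch yields a global average. A simplified sketch of how such a train function typically continues (assumes ``dist`` is ``torch.distributed`` and ``F`` is ``torch.nn.functional``; details may differ from the tutorial's exact code):

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
        model.train()
        ddp_loss = torch.zeros(2).to(rank)   # [summed loss, number of samples]
        if sampler:
            sampler.set_epoch(epoch)         # reshuffle the distributed sampler each epoch
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target, reduction='sum')
            loss.backward()
            optimizer.step()
            ddp_loss[0] += loss.item()
            ddp_loss[1] += len(data)
        # aggregate the loss and sample count across all ranks before reporting
        dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
        if rank == 0:
            print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))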
@@ -155,6 +160,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 2.3 Define a validation function
 
 .. code-block:: python
+
     def test(model, rank, world_size, test_loader):
         model.eval()
         correct = 0
@@ -284,15 +290,17 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
 
 We have recorded CUDA events to measure the time of FSDP model specifics. The CUDA event time was 110.85 seconds.
 
-.. code-block::
+.. code-block:: bash
+
     python FSDP_mnist.py
 
     CUDA event elapsed time on training loop 40.67462890625sec
 
 Wrapping the model with FSDP, the model will look as follows; we can see the model has been wrapped in one FSDP unit.
 Alternatively, we will look at adding the fsdp_auto_wrap_policy next and will discuss the differences.
 
-.. code-block::
+.. code-block:: bash
+
     FullyShardedDataParallel(
       (_fsdp_wrapped_module): FlattenParamsWrapper(
         (_fpw_module): Net(
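
The “CUDA event elapsed time” lines in this output come from timing the training loop with CUDA events. A sketch of that pattern (event variable names are illustrative; ``elapsed_time`` returns milliseconds, hence the division by 1000):

.. code-block:: python

    import torch

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    # ... run the training/validation loop here ...
    end_event.record()

    torch.cuda.synchronize()   # make sure all recorded work has finished
    if rank == 0:
        print(f"CUDA event elapsed time on training loop {start_event.elapsed_time(end_event) / 1000}sec")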
@@ -331,6 +339,7 @@ If the number of parameters in this layer is smaller than 100, it will be wrappe
 Finding an optimal auto wrap policy is challenging; PyTorch will add auto tuning for this config in the future. Without an auto tuning tool, it is good to profile your workflow using different auto wrap policies experimentally and find the optimal one.
 
 .. code-block:: python
+
     my_auto_wrap_policy = functools.partial(
         default_auto_wrap_policy, min_num_params=20000
     )
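
For context, a sketch of where ``default_auto_wrap_policy`` comes from and how the resulting policy is handed to the wrapper (assuming the PyTorch 1.11-era import path ``torch.distributed.fsdp.wrap``; later releases rename it to ``size_based_auto_wrap_policy``):

.. code-block:: python

    import functools
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import default_auto_wrap_policy

    # layers with fewer than min_num_params parameters are flattened together with
    # other small layers into a shared FSDP unit; larger layers get their own unit
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=20000
    )
    model = FSDP(model, fsdp_auto_wrap_policy=my_auto_wrap_policy)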
@@ -342,7 +351,7 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
 
 Applying the FSDP_auto_wrap_policy, the model would be as follows:
 
-.. code-block::
+.. code-block:: bash
 
     FullyShardedDataParallel(
       (_fsdp_wrapped_module): FlattenParamsWrapper(
@@ -361,7 +370,8 @@ Applying the FSDP_auto_wrap_policy, the model would be as follows:
     )
 
 
-.. code-block::
+.. code-block:: bash
+
     python FSDP_mnist.py
 
     CUDA event elapsed time on training loop 41.89130859375sec
@@ -388,6 +398,7 @@ In 2.4 we just add it to the FSDP wrapper
 
 
 .. code-block:: python
+
     model = FSDP(model,
         fsdp_auto_wrap_policy=my_auto_wrap_policy,
         cpu_offload=CPUOffload(offload_params=True))
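
For completeness, a sketch of the imports this snippet relies on; ``CPUOffload`` is exported from ``torch.distributed.fsdp``:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import CPUOffload

    # offload_params=True keeps the sharded parameters in CPU memory and copies them
    # to the GPU only when needed, trading extra host-device traffic for a smaller
    # GPU memory footprint
    model = FSDP(model,
        fsdp_auto_wrap_policy=my_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True))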
@@ -396,11 +407,13 @@ In 2.4 we just add it to the FSDP wrapper
 To compare it with DDP, in 2.4 we instead wrap the model in DDP as usual, saving the changes in “DDP_mnist.py”.
 
 .. code-block:: python
+
     model = Net().to(rank)
     model = DDP(model)
 
 
-.. code-block::
+.. code-block:: bash
+
     python DDP_mnist.py
 
     CUDA event elapsed time on training loop 39.77766015625sec
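
As with FSDP, the DDP snippet assumes the wrapper has been imported; the standard import is ``torch.nn.parallel.DistributedDataParallel``:

.. code-block:: python

    from torch.nn.parallel import DistributedDataParallel as DDP

    model = Net().to(rank)
    model = DDP(model)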
