Commit 4e9296e

[doc][c10d] Fixes to FSDP tutorial (#3138)
Summary: Fix up the actual FSDP tutorial to get it running again:
https://github.com/pytorch/examples/pull/1297/files. That tutorial is
referred to in this document. In addition, this document gets minor fixups:

1. Fix typo in link.
2. Add a grid card with prerequisites and what you will learn.
3. Add more links to the actual FSDP paper.
4. Stop referring to PyTorch nightly; instead tell the reader to get the
   latest PyTorch, as FSDP has been released for a while.

Test Plan: Render and examine.

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 60e29a0 commit 4e9296e

File tree

3 files changed: +57 -42 lines changed

distributed/home.rst (+1 -1)
@@ -77,7 +77,7 @@ Learn FSDP
 
    .. grid-item-card:: :octicon:`file-code;1em`
       FSDP Advanced
-      :link: https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html?utm_source=distr_landing&utm_medium=FSDP_advanced
+      :link: https://pytorch.org/tutorials/intermediate/FSDP_advanced_tutorial.html?utm_source=distr_landing&utm_medium=FSDP_advanced
       :link-type: url
 
       In this tutorial, you will learn how to fine-tune a HuggingFace (HF) T5

index.rst (+2 -2)
@@ -763,7 +763,7 @@ Welcome to PyTorch Tutorials
    :header: Advanced Model Training with Fully Sharded Data Parallel (FSDP)
    :card_description: Explore advanced model training with Fully Sharded Data Parallel package.
    :image: _static/img/thumbnails/cropped/Getting-Started-with-FSDP.png
-   :link: intermediate/FSDP_adavnced_tutorial.html
+   :link: intermediate/FSDP_advanced_tutorial.html
    :tags: Parallel-and-Distributed-Training
 
 .. customcarditem::
@@ -1115,7 +1115,7 @@ Additional Resources
    intermediate/ddp_tutorial
    intermediate/dist_tuto
    intermediate/FSDP_tutorial
-   intermediate/FSDP_adavnced_tutorial
+   intermediate/FSDP_advanced_tutorial
    intermediate/TCPStore_libuv_backend
    intermediate/TP_tutorial
    intermediate/pipelining_tutorial

intermediate_source/FSDP_adavnced_tutorial.rst renamed to intermediate_source/FSDP_advanced_tutorial.rst (+54 -39)
@@ -6,25 +6,44 @@ Wright <https://github.com/lessw2020>`__, `Rohan Varma
 <https://github.com/rohan-varma/>`__, `Yanli Zhao
 <https://github.com/zhaojuanmao>`__
 
+.. grid:: 2
+
+   .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+      :class-card: card-prerequisites
+
+      * PyTorch's Fully Sharded Data Parallel Module: A wrapper for sharding module parameters across
+        data parallel workers.
+
+
+
+
+   .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+      :class-card: card-prerequisites
+
+      * PyTorch 1.12 or later
+      * Read about the `FSDP API <https://pytorch.org/docs/main/fsdp.html>`__.
+
 
 This tutorial introduces more advanced features of Fully Sharded Data Parallel
 (FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please
 refer to the `FSDP getting started tutorial
 <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__.
 
 In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text
-summarization as a working example. 
+summarization as a working example.
 
 The example uses Wikihow and for simplicity, we will showcase the training on a
-single node, P4dn instance with 8 A100 GPUs. We will soon have a blog post on
-large scale FSDP training on a multi-node cluster, please stay tuned for that on
-the PyTorch medium channel.
+single node, P4dn instance with 8 A100 GPUs. We now have several blog posts (
+`(link1), <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__
+`(link2) <https://engineering.fb.com/2021/07/15/open-source/fsdp/>`__)
+and a `paper <https://arxiv.org/abs/2304.11277>`__ on
+large scale FSDP training on a multi-node cluster.
 
 FSDP is a production ready package with focus on ease of use, performance, and
 long-term support. One of the main benefits of FSDP is reducing the memory
 footprint on each GPU. This enables training of larger models with lower total
 memory vs DDP, and leverages the overlap of computation and communication to
-train models efficiently. 
+train models efficiently.
 This reduced memory pressure can be leveraged to either train larger models or
 increase batch size, potentially helping overall training throughput. You can
 read more about PyTorch FSDP `here
@@ -47,21 +66,21 @@ Recap on How FSDP Works
 
 At a high level FDSP works as follow:
 
-*In constructor*
+*In the constructor*
 
 * Shard model parameters and each rank only keeps its own shard
 
-*In forward pass*
+*In the forward pass*
 
 * Run `all_gather` to collect all shards from all ranks to recover the full
-  parameter for this FSDP unit Run forward computation
-* Discard non-owned parameter shards it has just collected to free memory
+  parameter for this FSDP unit and run the forward computation
+* Discard the non-owned parameter shards it has just collected to free memory
 
-*In backward pass*
+*In the backward pass*
 
 * Run `all_gather` to collect all shards from all ranks to recover the full
-  parameter in this FSDP unit Run backward computation
-* Discard non-owned parameters to free memory.
+  parameter in this FSDP unit and run backward computation
+* Discard non-owned parameters to free memory.
 * Run reduce_scatter to sync gradients
 
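The recap above maps directly onto a short script. Below is a minimal, illustrative sketch (not part of this commit) of wrapping a toy model with FSDP; it assumes the script is launched with ``torchrun`` on a host with at least one GPU and NCCL available.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


    def main():
        # torchrun provides RANK, WORLD_SIZE, MASTER_ADDR/PORT and LOCAL_RANK
        dist.init_process_group("nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

        model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()
        model = FSDP(model)  # constructor: parameters are sharded across ranks

        optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
        x = torch.randn(8, 512, device="cuda")

        loss = model(x).sum()  # forward: all_gather shards, compute, discard non-owned shards
        loss.backward()        # backward: all_gather again, compute grads, then reduce_scatter
        optim.step()           # each rank updates only the shard it owns

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()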

@@ -80,15 +99,11 @@ examples
 
 *Setup*
 
-1.1 Install PyTorch Nightlies
-
-We will install PyTorch nightlies, as some of the features such as activation
-checkpointing is available in nightlies and will be added in next PyTorch
-release after 1.12.
+1.1 Install the latest PyTorch
 
-.. code-block:: bash 
+.. code-block:: bash
 
-    pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
+    pip3 install torch torchvision torchaudio
 
 1.2 Dataset Setup
 
@@ -154,7 +169,7 @@ Next, we add the following code snippets to a Python script “T5_training.py”
     import tqdm
     from datetime import datetime
 
-1.4 Distributed training setup. 
+1.4 Distributed training setup.
 Here we use two helper functions to initialize the processes for distributed
 training, and then to clean up after training completion. In this tutorial, we
 are going to use torch elastic, using `torchrun
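The two helper functions described above live in the unchanged part of ``T5_training.py`` and are therefore not visible in this diff. A sketch of what they typically look like with the environment-variable initialization that ``torchrun`` provides (the exact bodies here are illustrative, not copied from the script):

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist


    def setup():
        # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
        # so env:// initialization needs no explicit arguments
        dist.init_process_group("nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


    def cleanup():
        dist.destroy_process_group()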
@@ -191,13 +206,13 @@ metrics.
         date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
         print(f"--> current date and time of run = {date_of_run}")
         return date_of_run
-
+
     def format_metrics_to_gb(item):
         """quick function to format numbers to gigabyte and round to 4 digit precision"""
         metric_num = item / g_gigabyte
         metric_num = round(metric_num, ndigits=4)
         return metric_num
-
+
 
 2.2 Define a train function:
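A hypothetical use of the ``format_metrics_to_gb`` helper shown above, for example to report peak GPU memory after a training step (the ``g_gigabyte`` constant is assumed to be defined at module level in the script):

.. code-block:: python

    import torch

    g_gigabyte = 1024 ** 3  # assumed module-level constant


    def format_metrics_to_gb(item):
        """quick function to format numbers to gigabyte and round to 4 digit precision"""
        return round(item / g_gigabyte, ndigits=4)


    if torch.cuda.is_available():
        peak = format_metrics_to_gb(torch.cuda.max_memory_allocated())
        print(f"--> peak GPU memory allocated: {peak} GB")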

@@ -275,7 +290,7 @@ metrics.
 
 .. code-block:: python
 
-
+
     def fsdp_main(args):
 
         model, tokenizer = setup_model("t5-base")
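``setup_model`` is defined in the unchanged part of the script, so it does not appear in this diff. An illustrative stand-in using the HuggingFace ``transformers`` API could look like this:

.. code-block:: python

    from transformers import T5ForConditionalGeneration, T5Tokenizer


    def setup_model(model_name="t5-base"):
        # illustrative only; the tutorial's own helper may differ in details
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        return model, tokenizer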
@@ -292,7 +307,7 @@ metrics.
 
 
         #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
-        train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False) 
+        train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
         val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)
 
         sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
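For context, a self-contained sketch of how a ``DistributedSampler`` is typically paired with a ``DataLoader``; a stand-in ``TensorDataset`` replaces the wikihow dataset here, and in the real script ``rank`` and ``world_size`` come from the process group:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    rank, world_size = 0, 1  # placeholders; real values come from torch.distributed
    dataset = TensorDataset(torch.randn(64, 512), torch.randn(64, 150))
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # different shuffle each epoch, consistent across ranks
        for source, target in loader:
            pass  # training step goes here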
@@ -430,7 +445,7 @@ metrics.
 
 .. code-block:: python
 
-
+
     if __name__ == '__main__':
        # Training settings
        parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
@@ -463,7 +478,7 @@ metrics.
 
 To run the the training using torchrun:
 
-.. code-block:: bash 
+.. code-block:: bash
 
     torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
 
@@ -487,7 +502,7 @@ communication efficient. In PyTorch 1.12, FSDP added this support and now we
 have a wrapping policy for transfomers.
 
 It can be created as follows, where the T5Block represents the T5 transformer
-layer class (holding MHSA and FFN). 
+layer class (holding MHSA and FFN).
 
 
 .. code-block:: python
@@ -499,7 +514,7 @@ layer class (holding MHSA and FFN).
         },
     )
     torch.cuda.set_device(local_rank)
-
+
 
     model = FSDP(model,
         auto_wrap_policy=t5_auto_wrap_policy)
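For readers who want the wrapping policy above in one self-contained piece, it can be written out as follows; the ``T5Block`` import path assumes the HuggingFace ``transformers`` package is installed:

.. code-block:: python

    import functools

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.t5.modeling_t5 import T5Block

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={T5Block},  # each T5Block (MHSA + FFN) becomes its own FSDP unit
    )
    # model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy)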
@@ -513,22 +528,22 @@ Mixed Precision
 FSDP supports flexible mixed precision training allowing for arbitrary reduced
 precision types (such as fp16 or bfloat16). Currently BFloat16 is only available
 on Ampere GPUs, so you need to confirm native support before you use it. On
-V100s for example, BFloat16 can still be run but due to it running non-natively,
+V100s for example, BFloat16 can still be run but because it runs non-natively,
 it can result in significant slowdowns.
 
 To check if BFloat16 is natively supported, you can use the following :
 
 .. code-block:: python
-
+
     bf16_ready = (
         torch.version.cuda
-        and torch.cuda.is_bf16_supported() 
+        and torch.cuda.is_bf16_supported()
         and LooseVersion(torch.version.cuda) >= "11.0"
         and dist.is_nccl_available()
         and nccl.version() >= (2, 10)
     )
 
-One of the advantages of mixed percision in FSDP is providing granular control
+One of the advantages of mixed precision in FSDP is providing granular control
 over different precision levels for parameters, gradients, and buffers as
 follows:
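The ``bfSixteen`` policy referenced later in this diff is defined in an unchanged part of the file. A sketch of such a policy, illustrating the per-tensor-kind control described above:

.. code-block:: python

    import torch
    from torch.distributed.fsdp import MixedPrecision

    bfSixteen = MixedPrecision(
        param_dtype=torch.bfloat16,   # parameters are gathered and used in bf16
        reduce_dtype=torch.bfloat16,  # gradient reduce_scatter/all_reduce runs in bf16
        buffer_dtype=torch.bfloat16,  # buffers are kept in bf16
    )

    # a gradient-communication-only variant keeps params and buffers in fp32
    grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)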

@@ -571,7 +586,7 @@ with the following policy:
 .. code-block:: bash
 
     grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
-
+
 
 In 2.4 we just add the relevant mixed precision policy to the FSDP wrapper:
 
@@ -604,9 +619,9 @@ CPU-based initialization:
         auto_wrap_policy=t5_auto_wrap_policy,
         mixed_precision=bfSixteen,
         device_id=torch.cuda.current_device())
-
 
-
+
+
 Sharding Strategy
 -----------------
 FSDP sharding strategy by default is set to fully shard the model parameters,
@@ -627,7 +642,7 @@ instead of "ShardingStrategy.FULL_SHARD" to the FSDP initialization as follows:
         sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)
 
 This will reduce the communication overhead in FSDP, in this case, it holds full
-parameters after forward and through the backwards pass. 
+parameters after forward and through the backwards pass.
 
 This saves an all_gather during backwards so there is less communication at the
 cost of a higher memory footprint. Note that full model params are freed at the
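A minimal, illustrative sketch (launch with ``torchrun``; a toy linear layer stands in for the T5 model) of selecting the ZeRO-2 style strategy described above:

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # ZeRO-2: full parameters are kept between forward and backward,
    # while gradients and optimizer state stay sharded
    model = FSDP(
        nn.Linear(1024, 1024).cuda(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    )
    model(torch.randn(4, 1024, device="cuda")).sum().backward()
    dist.destroy_process_group()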
@@ -652,12 +667,12 @@ wrapper in 2.4 as follows:
         mixed_precision=bfSixteen,
         device_id=torch.cuda.current_device(),
         backward_prefetch = BackwardPrefetch.BACKWARD_PRE)
-
+
 `backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`.
 `BACKWARD_POST` means that the next FSDP unit's params will not be requested
 until the current FSDP unit processing is complete, thus minimizing memory
 overhead. In some cases, using `BACKWARD_PRE` can increase model training speed
-up to 2-10%, with even higher speed improvements noted for larger models.
+up to 2-10%, with even higher speed improvements noted for larger models.
 
 Model Checkpoint Saving, by streaming to the Rank0 CPU
 ------------------------------------------------------
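The checkpointing code itself is in the unchanged part of the file and not shown in this diff. A sketch of the rank0-CPU streaming pattern the section describes, wrapped in an illustrative helper (the function name is hypothetical, not from the tutorial):

.. code-block:: python

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import (
        FullStateDictConfig,
        FullyShardedDataParallel as FSDP,
        StateDictType,
    )


    def save_full_checkpoint(model, path="T5_checkpoint.pt"):
        # gather the sharded state dict, offload it to CPU on rank 0 only, and save it there
        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
            cpu_state = model.state_dict()
        if dist.get_rank() == 0:
            torch.save(cpu_state, path)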
@@ -696,7 +711,7 @@ Pytorch 1.12 and used HF T5 as the running example. Using the proper wrapping
 policy especially for transformer models, along with mixed precision and
 backward prefetch should speed up your training runs. Also, features such as
 initializing the model on device, and checkpoint saving via streaming to CPU
-should help to avoid OOM error in dealing with large models. 
+should help to avoid OOM error in dealing with large models.
 
 We are actively working to add new features to FSDP for the next release. If
 you have feedback, feature requests, questions or are encountering issues
