@@ -6,25 +6,44 @@ Wright <https://github.com/lessw2020>`__, `Rohan Varma
<https://github.com/rohan-varma/>`__, `Yanli Zhao
<https://github.com/zhaojuanmao>`__

+ .. grid:: 2
+
+    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+       :class-card: card-prerequisites
+
+       * PyTorch's Fully Sharded Data Parallel Module: A wrapper for sharding module parameters across
+         data parallel workers.
+
+
+
+
+    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+       :class-card: card-prerequisites
+
+       * PyTorch 1.12 or later
+       * Read about the `FSDP API <https://pytorch.org/docs/main/fsdp.html>`__.
+

This tutorial introduces more advanced features of Fully Sharded Data Parallel
(FSDP) as part of the PyTorch 1.12 release. To get familiar with FSDP, please
refer to the `FSDP getting started tutorial
<https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__.

In this tutorial, we fine-tune a HuggingFace (HF) T5 model with FSDP for text
- summarization as a working example.
+ summarization as a working example.

The example uses Wikihow and for simplicity, we will showcase the training on a
- single node, P4dn instance with 8 A100 GPUs. We will soon have a blog post on
- large scale FSDP training on a multi-node cluster, please stay tuned for that on
- the PyTorch medium channel.
+ single node, P4dn instance with 8 A100 GPUs. We now have several blog posts
+ (`link1 <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__,
+ `link2 <https://engineering.fb.com/2021/07/15/open-source/fsdp/>`__)
+ and a `paper <https://arxiv.org/abs/2304.11277>`__ on
+ large scale FSDP training on a multi-node cluster.

FSDP is a production-ready package with a focus on ease of use, performance, and
long-term support. One of the main benefits of FSDP is reducing the memory
footprint on each GPU. This enables training of larger models with lower total
memory vs DDP, and leverages the overlap of computation and communication to
- train models efficiently.
+ train models efficiently.

This reduced memory pressure can be leveraged to either train larger models or
increase batch size, potentially helping overall training throughput. You can
read more about PyTorch FSDP `here
@@ -47,21 +66,21 @@ Recap on How FSDP Works

At a high level FSDP works as follows:

- *In constructor*
+ *In the constructor*

* Shard model parameters and each rank only keeps its own shard

- *In forward pass*
+ *In the forward pass*

* Run `all_gather` to collect all shards from all ranks to recover the full
-   parameter for this FSDP unit Run forward computation
- * Discard non-owned parameter shards it has just collected to free memory
+   parameter for this FSDP unit and run the forward computation
+ * Discard the non-owned parameter shards it has just collected to free memory

- *In backward pass*
+ *In the backward pass*

* Run `all_gather` to collect all shards from all ranks to recover the full
-   parameter in this FSDP unit Run backward computation
- * Discard non-owned parameters to free memory.
+   parameter in this FSDP unit and run backward computation
+ * Discard non-owned parameters to free memory.
* Run reduce_scatter to sync gradients

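As a rough illustration of this flow, the following minimal sketch (assuming a
``torchrun`` launch that sets ``LOCAL_RANK``, and a toy linear layer rather than
the T5 model used later in this tutorial) wraps a model in FSDP and runs one
training step:

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # constructor: parameters are sharded across ranks; each rank keeps one shard
    model = FSDP(torch.nn.Linear(1024, 1024).to(local_rank))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    x = torch.randn(8, 1024, device=local_rank)
    loss = model(x).sum()   # forward: all_gather full params per FSDP unit, then free non-owned shards
    loss.backward()         # backward: all_gather again, compute grads, then reduce_scatter to sync them
    optimizer.step()        # each rank updates only its own parameter shard

    dist.destroy_process_group()
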
@@ -80,15 +99,11 @@ examples

*Setup*

- 1.1 Install PyTorch Nightlies
-
- We will install PyTorch nightlies, as some of the features such as activation
- checkpointing is available in nightlies and will be added in next PyTorch
- release after 1.12.
+ 1.1 Install the latest PyTorch

- .. code-block:: bash
+ .. code-block:: bash

-    pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
+    pip3 install torch torchvision torchaudio

1.2 Dataset Setup

@@ -154,7 +169,7 @@ Next, we add the following code snippets to a Python script “T5_training.py”

    import tqdm
    from datetime import datetime

- 1.4 Distributed training setup.
+ 1.4 Distributed training setup.
Here we use two helper functions to initialize the processes for distributed
training, and then to clean up after training completion. In this tutorial, we
are going to use torch elastic, using `torchrun
@@ -191,13 +206,13 @@ metrics.

        date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
        print(f"--> current date and time of run = {date_of_run}")
        return date_of_run
-
+

    def format_metrics_to_gb(item):
        """quick function to format numbers to gigabyte and round to 4 digit precision"""
        metric_num = item / g_gigabyte
        metric_num = round(metric_num, ndigits=4)
        return metric_num
-
+

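These helpers are handy for logging; for example, a hypothetical way to report
current GPU memory with ``format_metrics_to_gb`` could look like:

.. code-block:: python

    # hypothetical logging snippet using the helper above
    allocated_gb = format_metrics_to_gb(torch.cuda.memory_allocated())
    reserved_gb = format_metrics_to_gb(torch.cuda.memory_reserved())
    print(f"--> GPU memory allocated: {allocated_gb} GB, reserved: {reserved_gb} GB")
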
2.2 Define a train function:

@@ -275,7 +290,7 @@ metrics.

.. code-block:: python

-
+

    def fsdp_main(args):

        model, tokenizer = setup_model("t5-base")
@@ -292,7 +307,7 @@ metrics.

        # wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
-        train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
+        train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
        val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

        sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
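        # With shuffle=True, DistributedSampler expects sampler1.set_epoch(epoch) to be
        # called at the start of each epoch so that shuffling differs across epochs.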
@@ -430,7 +445,7 @@ metrics.

.. code-block:: python

-
+

    if __name__ == '__main__':
        # Training settings
        parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
@@ -463,7 +478,7 @@ metrics.

To run the training using torchrun:

- .. code-block:: bash
+ .. code-block:: bash

    torchrun --nnodes 1 --nproc_per_node 4 T5_training.py

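The same script can be scaled to multiple nodes by pointing ``torchrun`` at a
shared rendezvous endpoint; a hypothetical two-node launch (where ``NODE0_IP``
is a placeholder for the rank-0 host address) might look like:

.. code-block:: bash

    # hypothetical two-node launch; NODE0_IP stands in for the rank-0 host address
    torchrun --nnodes 2 --nproc_per_node 8 \
        --rdzv_backend c10d --rdzv_endpoint NODE0_IP:29500 \
        T5_training.py
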
@@ -487,7 +502,7 @@ communication efficient. In PyTorch 1.12, FSDP added this support and now we
have a wrapping policy for transformers.

It can be created as follows, where the T5Block represents the T5 transformer
- layer class (holding MHSA and FFN).
+ layer class (holding MHSA and FFN).

.. code-block:: python
@@ -499,7 +514,7 @@ layer class (holding MHSA and FFN).
        },
    )
    torch.cuda.set_device(local_rank)
-
+

    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy)
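For reference, the ``t5_auto_wrap_policy`` used above is typically built from
FSDP's ``transformer_auto_wrap_policy`` helper; a sketch, assuming the HF
``transformers`` import path for ``T5Block``:

.. code-block:: python

    # sketch of the wrapping policy referenced above
    import functools

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.t5.modeling_t5 import T5Block

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={T5Block},
    )
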
@@ -513,22 +528,22 @@ Mixed Precision

FSDP supports flexible mixed precision training allowing for arbitrary reduced
precision types (such as fp16 or bfloat16). Currently BFloat16 is only available
on Ampere GPUs, so you need to confirm native support before you use it. On
- V100s for example, BFloat16 can still be run but due to it running non-natively,
+ V100s for example, BFloat16 can still be run but because it runs non-natively,
it can result in significant slowdowns.

To check if BFloat16 is natively supported, you can use the following:

.. code-block:: python

-
+

    bf16_ready = (
        torch.version.cuda
-        and torch.cuda.is_bf16_supported()
+        and torch.cuda.is_bf16_supported()
        and LooseVersion(torch.version.cuda) >= "11.0"
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )

- One of the advantages of mixed percision in FSDP is providing granular control
+ One of the advantages of mixed precision in FSDP is providing granular control
over different precision levels for parameters, gradients, and buffers as
follows:

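As a sketch, a bfloat16 policy (using the ``bfSixteen`` name that appears in the
FSDP wrapper below) and its selection based on the ``bf16_ready`` check above
could look like:

.. code-block:: python

    from torch.distributed.fsdp import MixedPrecision

    # sketch: cast parameters, gradient reduction, and buffers to bfloat16
    bfSixteen = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # fall back to full precision if native bfloat16 support is missing
    mp_policy = bfSixteen if bf16_ready else None
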
@@ -571,7 +586,7 @@ with the following policy:

.. code-block:: python

    grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)
-
+

In 2.4 we just add the relevant mixed precision policy to the FSDP wrapper:

@@ -604,9 +619,9 @@ CPU-based initialization:
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())
-
-
+
+

Sharding Strategy
-----------------
FSDP sharding strategy by default is set to fully shard the model parameters,
@@ -627,7 +642,7 @@ instead of "ShardingStrategy.FULL_SHARD" to the FSDP initialization as follows:

        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

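In context, the full wrapper call would pass this argument alongside the options
from 2.4; a sketch, reusing the ``t5_auto_wrap_policy`` and ``bfSixteen`` names
from earlier sections:

.. code-block:: python

    from torch.distributed.fsdp import ShardingStrategy

    # sketch: ZeRO-2 style sharding, reusing names defined earlier in the tutorial
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)  # ZERO2
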
This will reduce the communication overhead in FSDP; in this case, it holds full
- parameters after forward and through the backwards pass.
+ parameters after forward and through the backwards pass.

This saves an all_gather during backwards so there is less communication at the
cost of a higher memory footprint. Note that full model params are freed at the
@@ -652,12 +667,12 @@ wrapper in 2.4 as follows:
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE)
-
+

`backward_prefetch` has two modes, `BACKWARD_PRE` and `BACKWARD_POST`.
`BACKWARD_POST` means that the next FSDP unit's params will not be requested
until the current FSDP unit processing is complete, thus minimizing memory
overhead. In some cases, using `BACKWARD_PRE` can increase model training speed
- up to 2-10%, with even higher speed improvements noted for larger models.
+ up to 2-10%, with even higher speed improvements noted for larger models.

Model Checkpoint Saving, by streaming to the Rank0 CPU
------------------------------------------------------
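As a sketch of this approach (assuming ``rank`` holds the process's global rank;
the filename below is a placeholder), a full state dict can be gathered to CPU
and materialized only on rank 0 with ``FullStateDictConfig``:

.. code-block:: python

    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    # gather the full state dict, offloaded to CPU and kept only on rank 0
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()

    if rank == 0:
        torch.save(cpu_state, "t5_checkpoint.pt")  # placeholder filename
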
@@ -696,7 +711,7 @@ Pytorch 1.12 and used HF T5 as the running example. Using the proper wrapping
policy especially for transformer models, along with mixed precision and
backward prefetch should speed up your training runs. Also, features such as
initializing the model on device, and checkpoint saving via streaming to CPU
- should help to avoid OOM errors in dealing with large models.
+ should help to avoid OOM errors in dealing with large models.

We are actively working to add new features to FSDP for the next release. If
you have feedback, feature requests, questions or are encountering issues