intermediate_source/FSDP_tutorial.rst (+28 -15)
@@ -9,7 +9,7 @@ It also comes with considerable engineering complexity to handle the training of
`Pytorch FSDP <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__, released in PyTorch 1.11 makes this easier.

In this tutorial, we show how to use `FSDP APIs <https://pytorch.org/docs/1.11/fsdp.html>`__, for simple MNIST models that can be extended to other larger models such as `HuggingFace BERT models <https://huggingface.co/blog/zero-deepspeed-fairscale>`__,
-`GPT 3 models up to 1T parameters <https://pytorch.medium.com/pytorch-data-parallel-best-practices-on-google-cloud-6c8da2be180d>`__ . The sample DDP MNIST code has been borrowed from `here <https://github.com/yqhu/mnist_examples>`__.
+`GPT 3 models up to 1T parameters <https://pytorch.medium.com/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__ . The sample DDP MNIST code has been borrowed from `here <https://github.com/yqhu/mnist_examples>`__.


How FSDP works
@@ -28,18 +28,21 @@ FSDP GPU memory footprint would be smaller than DDP across all workers. This mak
At a high level FSDP works as follows:

*In constructor*
-Shard model parameters and each rank only keeps its own shard
+
+* Shard model parameters and each rank only keeps its own shard

*In forward path*
-Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
-Run forward computation
-Discard parameter shards it has just collected
+
+* Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
+* Run forward computation
+* Discard parameter shards it has just collected

*In backward path*
-Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
-Run backward computation
-Run reduce_scatter to sync gradients
-Discard parameters.
+
+* Run allgather to collect all shards from all ranks to recover the full parameter in this FSDP unit
+* Run backward computation
+* Run reduce_scatter to sync gradients
+* Discard parameters.

How to use FSDP
--------------
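To make the constructor / forward / backward steps listed above concrete, a minimal sketch of wrapping a model in a single FSDP unit might look as follows (illustrative only: the ``setup``/``wrap`` helper names are assumptions, and ``MASTER_ADDR``/``MASTER_PORT`` are assumed to be set in the environment):

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def setup(rank, world_size):
        # one process per GPU; NCCL backend for CUDA tensors
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

    def wrap(rank):
        # the FSDP constructor shards the parameters so each rank keeps only
        # its own shard; forward/backward all-gather the full parameters per
        # FSDP unit and discard them again afterwards
        model = nn.Linear(1024, 1024).to(rank)
        return FSDP(model)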
@@ -49,7 +52,8 @@ Here we use a toy model to run training on MNIST dataset for demonstration purpo
@@ -155,6 +160,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
2.3 Define a validation function

.. code-block:: python
+
    def test(model, rank, world_size, test_loader):
        model.eval()
        correct = 0
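The hunk above shows only the first lines of ``test``. For orientation, a hedged sketch of how such a distributed validation loop typically continues is below; the metric aggregation and printout are illustrative assumptions, not the exact contents of FSDP_mnist.py.

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def test(model, rank, world_size, test_loader):
        model.eval()
        correct = 0
        total = 0
        loss_sum = 0.0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(rank), target.to(rank)
                output = model(data)
                # assumes a log-softmax classifier head, as in the MNIST example
                loss_sum += F.nll_loss(output, target, reduction="sum").item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += len(data)
        # aggregate per-rank metrics before reporting on rank 0
        stats = torch.tensor([loss_sum, correct, total], dtype=torch.float64).to(rank)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        if rank == 0:
            avg_loss = (stats[0] / stats[2]).item()
            accuracy = (100.0 * stats[1] / stats[2]).item()
            print(f"Test loss: {avg_loss:.4f}, accuracy: {accuracy:.2f}%")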
@@ -284,15 +290,17 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
We have recorded CUDA events to measure the time of FSDP model specifics. The CUDA event time was 110.85 seconds.

-.. code-block::
+.. code-block:: bash
+
    python FSDP_mnist.py

    CUDA event elapsed time on training loop 40.67462890625sec

Wrapping the model with FSDP, the model will look as follows; we can see the model has been wrapped in one FSDP unit.
Alternatively, we will look at adding the fsdp_auto_wrap_policy next and will discuss the differences.

-.. code-block::
+.. code-block:: bash
+
    FullyShardedDataParallel(
      (_fsdp_wrapped_module): FlattenParamsWrapper(
        (_fpw_module): Net(
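For reference, CUDA-event timing of the kind quoted above can be captured roughly like this (a sketch with illustrative variable names; ``Event.elapsed_time`` returns milliseconds):

.. code-block:: python

    import torch

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    # ... run the training loop here ...
    end_event.record()

    # wait for all queued GPU work to finish before reading the timer
    torch.cuda.synchronize()
    elapsed_sec = start_event.elapsed_time(end_event) / 1000
    print(f"CUDA event elapsed time on training loop {elapsed_sec}sec")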
@@ -331,6 +339,7 @@ If the number of parameters in this layer is smaller than 100, it will be wrappe
Finding an optimal auto wrap policy is challenging; PyTorch will add auto tuning for this config in the future. Without an auto tuning tool, it is good to profile your workflow using different auto wrap policies experimentally and find the optimal one.

.. code-block:: python
+
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=20000
    )
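Putting the policy to use matches the wrapper call shown later in this diff. A self-contained sketch might be the following; note that the import path reflects the PyTorch 1.11-era API and is an assumption here (later releases renamed these symbols):

.. code-block:: python

    import functools
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import default_auto_wrap_policy  # assumed 1.11 path

    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=20000
    )

    # assumes the default process group is already initialized and the model
    # is on the current GPU; submodules with fewer than 20000 parameters are
    # folded into their parent FSDP unit rather than wrapped separately
    model = nn.Sequential(nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()
    model = FSDP(model, fsdp_auto_wrap_policy=my_auto_wrap_policy)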
@@ -342,7 +351,7 @@ Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning
Applying the FSDP_auto_wrap_policy, the model would be as follows:

-.. code-block::
+.. code-block:: bash

    FullyShardedDataParallel(
      (_fsdp_wrapped_module): FlattenParamsWrapper(
@@ -361,7 +370,8 @@ Applying the FSDP_auto_wrap_policy, the model would be as follows:
    )


-.. code-block::
+.. code-block:: bash
+
    python FSDP_mnist.py

    CUDA event elapsed time on training loop 41.89130859375sec
@@ -388,6 +398,7 @@ In 2.4 we just add it to the FSDP wrapper
.. code-block:: python
+
    model = FSDP(model,
        fsdp_auto_wrap_policy=my_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True))
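The ``CPUOffload`` config used above has to be imported alongside FSDP. A hedged sketch of the relevant imports and wrapper call (import paths per the PyTorch 1.11-era API, assumed rather than shown in this diff):

.. code-block:: python

    import functools
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import CPUOffload  # assumed import path
    from torch.distributed.fsdp.wrap import default_auto_wrap_policy

    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=20000
    )

    def wrap_with_offload(model):
        # offload_params=True keeps the sharded parameters (and their gradients)
        # in host memory between uses, trading GPU memory for extra host-device
        # transfers
        return FSDP(model,
                    fsdp_auto_wrap_policy=my_auto_wrap_policy,
                    cpu_offload=CPUOffload(offload_params=True))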
@@ -396,11 +407,13 @@ In 2.4 we just add it to the FSDP wrapper
To compare with DDP, in 2.4 we simply wrap the model in DDP instead, saving the changes in “DDP_mnist.py”.

.. code-block:: python
+
    model = Net().to(rank)
    model = DDP(model)


-.. code-block::
+.. code-block:: bash
+
    python DDP_mnist.py

    CUDA event elapsed time on training loop 39.77766015625sec
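For completeness, a sketch of the DDP baseline that the comparison assumes (illustrative: ``Net`` is the MNIST model defined earlier in the tutorial, and passing ``device_ids`` explicitly is an addition here, not part of the diff):

.. code-block:: python

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def build_ddp_model(rank, world_size):
        # assumes MASTER_ADDR/MASTER_PORT are set; one process per GPU
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
        model = Net().to(rank)  # Net: the MNIST model from the tutorial
        return DDP(model, device_ids=[rank])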