:orphan:

.. _gpu_intermediate:

GPU training (Intermediate)
===========================
**Audience:** Users looking to train across machines or experiment with different scaling techniques.

----

Distributed training strategies
-------------------------------
Lightning supports multiple ways of doing distributed training.

- Regular (``strategy='ddp'``)
- Spawn (``strategy='ddp_spawn'``)
- Notebook/Fork (``strategy='ddp_notebook'``)

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/yt/Trainer+flags+4-+multi+node+training_3.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/yt_thumbs/thumb_multi_gpus.png
    :width: 400

.. note::
    If you request multiple GPUs or nodes without setting a strategy, DDP will be automatically used.

----

Distributed Data Parallel
^^^^^^^^^^^^^^^^^^^^^^^^^

:class:`~torch.nn.parallel.DistributedDataParallel` (DDP) works as follows:

1. Each GPU across each node gets its own process.
2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset.
3. Each process inits the model.
4. Each process performs a full forward and backward pass in parallel.
5. The gradients are synced and averaged across all processes.
6. Each process updates its optimizer.

|

.. code-block:: python

    # train on 8 GPUs (same machine, i.e. one node)
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp")

    # train on 32 GPUs (4 nodes)
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp", num_nodes=4)

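For reference, a minimal sketch of a complete training script that could be launched with the settings above is shown below. The ``LitModel`` and the random dataset are hypothetical placeholders, not part of Lightning.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import lightning.pytorch as pl


    class LitModel(pl.LightningModule):
        # hypothetical minimal model, used only for illustration
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        # random tensors standing in for a real dataset
        dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        trainer = pl.Trainer(accelerator="gpu", devices=8, strategy="ddp", max_epochs=1)
        trainer.fit(LitModel(), DataLoader(dataset, batch_size=32))
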
This Lightning implementation of DDP calls your script under the hood multiple times with the correct environment
variables:

.. code-block:: bash

    # example for 3 GPUs DDP
    MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python my_file.py --accelerator 'gpu' --devices 3 --etc
    MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=1 python my_file.py --accelerator 'gpu' --devices 3 --etc
    MASTER_ADDR=localhost MASTER_PORT=random() WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=2 python my_file.py --accelerator 'gpu' --devices 3 --etc

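Because every process runs the same code, it is sometimes useful to know which rank a process has, for example to download data or write files only once. A minimal sketch with a hypothetical ``LitModel``, using the rank helpers Lightning provides:

.. code-block:: python

    import lightning.pytorch as pl
    from lightning.pytorch.utilities import rank_zero_only


    class LitModel(pl.LightningModule):
        # hypothetical model, used only for illustration
        def on_train_start(self):
            # every process runs the hooks; the rank attributes tell them apart
            print(f"node rank {self.trainer.node_rank}, global rank {self.global_rank}")

        @rank_zero_only
        def write_report(self):
            # decorated methods run on the main process (global rank 0) only
            print("only rank 0 executes this")
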
Using DDP this way has a few advantages over ``torch.multiprocessing.spawn()``:

1. All processes (including the main process) participate in training and have the up-to-date state of the model and Trainer.
2. No multiprocessing pickle errors.
3. Easily scales to multi-node training.

|

It is NOT possible to use DDP in interactive environments like Jupyter Notebook, Google Colab, Kaggle, etc.
In these situations you should use ``ddp_notebook``.

----

Distributed Data Parallel Spawn
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. warning:: It is STRONGLY recommended to use DDP for speed and performance.

The ``ddp_spawn`` strategy is similar to ``ddp`` except that it uses ``torch.multiprocessing.spawn()`` to start the training processes.
Use this for debugging only, or if you are converting a code base to Lightning that relies on spawn.

.. code-block:: python

    # train on 8 GPUs (same machine, i.e. one node)
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp_spawn")

We STRONGLY discourage this use because it has limitations (due to Python and PyTorch):

1. After ``.fit()``, only the model's weights are restored to the main process; the rest of the Trainer state is not.
2. Does not support multi-node training.
3. It is generally slower than DDP.

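A practical consequence of the pickling requirement (see the comparison table below) is that everything attached to your model must be picklable, because ``torch.multiprocessing.spawn()`` pickles objects when handing them to the new processes. A minimal sketch, with a hypothetical ``LitModel``, of an attribute that breaks ``ddp_spawn`` but works with regular ``ddp``:

.. code-block:: python

    import torch

    import lightning.pytorch as pl


    class LitModel(pl.LightningModule):
        # hypothetical model, used only for illustration
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            # lambdas cannot be pickled, so spawning worker processes with this
            # attribute attached raises a pickling error under ddp_spawn
            self.scale = lambda x: x * 0.5

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(self.scale(x)), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

Replacing the lambda with a named function or an ``nn.Module`` avoids the pickling error.
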
----

Distributed Data Parallel in Notebooks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DDP Notebook/Fork is an alternative to Spawn that can be used in interactive Python and Jupyter notebooks, Google Colab, Kaggle notebooks, and so on.
The Trainer enables it by default when such environments are detected.

.. code-block:: python

    # train on 8 GPUs in a Jupyter notebook
    trainer = Trainer(accelerator="gpu", devices=8)

    # can be set explicitly
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp_notebook")

    # can also be used in non-interactive environments
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp_fork")

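In practice you build the model and data in one cell and call ``fit`` in another. As the comparison table below notes, avoid any CUDA calls before ``fit`` when forking. A minimal sketch, where ``LitModel`` is a hypothetical LightningModule defined in an earlier cell:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # build everything on CPU first; a forked process cannot use CUDA that the
    # parent notebook process has already initialized
    model = LitModel()  # hypothetical LightningModule from an earlier cell
    train_loader = DataLoader(TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,))), batch_size=32)

    # the forked worker processes move the model and data to the GPUs
    trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp_notebook")
    trainer.fit(model, train_loader)
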
Among the native distributed strategies, regular DDP (``strategy="ddp"``) is still recommended over Spawn and Fork/Notebook for its speed and stability, but it can only be used with scripts.

----

Comparison of DDP variants and tradeoffs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. list-table:: DDP variants and their tradeoffs
    :widths: 40 20 20 20
    :header-rows: 1

    * -
      - DDP
      - DDP Spawn
      - DDP Notebook/Fork
    * - Works in Jupyter notebooks / IPython environments
      - No
      - No
      - Yes
    * - Supports multi-node
      - Yes
      - Yes
      - Yes
    * - Supported platforms
      - Linux, Mac, Win
      - Linux, Mac, Win
      - Linux, Mac
    * - Requires all objects to be picklable
      - No
      - Yes
      - No
    * - Limitations in the main process
      - None
      - The state of objects is not up-to-date after returning to the main process (``Trainer.fit()``, etc.). Only the model parameters get transferred over.
      - GPU operations such as moving tensors to the GPU or calling ``torch.cuda`` functions before invoking ``Trainer.fit`` are not allowed.
    * - Process creation time
      - Slow
      - Slow
      - Fast

----

TorchRun (TorchElastic)
-----------------------

Lightning supports the use of TorchRun (previously known as TorchElastic) to enable fault-tolerant and elastic distributed job scheduling.
To use it, specify the DDP strategy and the number of GPUs you want to use in the Trainer.

.. code-block:: python

    Trainer(accelerator="gpu", devices=8, strategy="ddp")

Then simply launch your script with the :doc:`torchrun <../clouds/cluster_intermediate_2>` command.
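For example, a single-node launch that starts one process per GPU might look like the following; the script name is a placeholder, and the multi-node rendezvous flags are covered in the linked guide:

.. code-block:: bash

    # start 8 processes (one per GPU) on this machine; my_file.py is a placeholder
    torchrun --nproc_per_node=8 --nnodes=1 my_file.py
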
----

Optimize multi-machine communication
------------------------------------

By default, Lightning will select the ``nccl`` backend over ``gloo`` when running on GPUs.
Find more information about PyTorch's supported backends `here <https://pytorch.org/docs/stable/distributed.html>`__.
Lightning allows explicitly specifying the backend via the ``process_group_backend`` constructor argument on the relevant Strategy classes. If it is not set, Lightning selects the appropriate process group backend based on the hardware used.

.. code-block:: python

    from lightning.pytorch.strategies import DDPStrategy

    # Explicitly specify the process group backend if you choose to
    ddp = DDPStrategy(process_group_backend="nccl")

    # Configure the strategy on the Trainer
    trainer = Trainer(strategy=ddp, accelerator="gpu", devices=8)

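Conversely, when debugging a multi-process setup on a machine without GPUs, you could explicitly pick ``gloo``, PyTorch's standard CPU backend (a sketch, not a recommended production setup):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DDPStrategy

    # gloo runs on CPU, which makes it handy for local multi-process debugging
    ddp_cpu = DDPStrategy(process_group_backend="gloo")
    trainer = Trainer(strategy=ddp_cpu, accelerator="cpu", devices=4)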