intermediate_source/TP_tutorial.rst: 35 additions & 12 deletions
@@ -164,6 +164,22 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
     )

 Now that we have elaborated the sharding plan for each ``TransformerBlock``, there is usually a ``nn.Embedding`` in the first layer and a final ``nn.Linear`` projection layer, where the user could choose row-wise or column-wise sharding for the first ``nn.Embedding`` and column-wise sharding for the last ``nn.Linear`` projection layer, with proper input and output layouts specified.
+Here is an example:
+
+.. code-block:: python
+
+    model = parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+            ),
+            "output": ColwiseParallel(
+                output_layouts=Replicate(),
+            ),
+        }
+    )

 .. note::
     If the model to be partitioned is too large to fit into CPU memory, one could either use ``meta`` device initialization (for example, initialize the model on the meta device first, shard the layers, and then materialize the model), or parallelize the ``TransformerBlock`` layer by layer during the Transformer model initialization.
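The ``meta`` device workflow mentioned in the note above can be sketched roughly as follows. This is only an illustration of the idea, not part of the diff: ``Transformer``, ``model_args``, ``init_weights``, and ``tp_plan`` are assumed placeholder names, and the sharded parameters are materialized with ``nn.Module.to_empty`` before being re-initialized.

.. code-block:: python

    import torch
    from torch.distributed.tensor.parallel import parallelize_module

    # Build the module on the meta device: parameter shapes are recorded,
    # but no real memory is allocated yet.
    with torch.device("meta"):
        model = Transformer(model_args)   # placeholder model class and args

    # Shard the layers while the parameters are still meta tensors.
    model = parallelize_module(model, tp_mesh, tp_plan)

    # Materialize only the local shards on the actual device, then re-initialize them.
    model = model.to_empty(device="cuda")
    model.init_weights()                  # placeholder initialization helper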
@@ -203,6 +219,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
     layer_tp_plan = {
         # Now the input and output of SequenceParallel have Shard(1) layouts,
         # to represent the input/output tensors sharded on the sequence dimension
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
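For orientation, the entries shown in this hunk sit inside a larger per-block plan. A sketch of what the full sequence-parallel ``layer_tp_plan`` commonly looks like is below, assuming the Llama-style submodule names (``wq``/``wk``/``wv``/``wo``, ``w1``/``w2``/``w3``) used elsewhere in this tutorial; the exact entries in the file may differ.

.. code-block:: python

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        PrepareModuleInput,
        RowwiseParallel,
        SequenceParallel,
    )
    # Shard/Replicate live in torch.distributed.tensor in recent releases
    # (torch.distributed._tensor in older ones).
    from torch.distributed.tensor import Replicate, Shard

    layer_tp_plan = {
        # norm layers keep their activations sharded on the sequence dimension
        "attention_norm": SequenceParallel(),
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        # the row-wise output projection hands back a sequence-sharded activation
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        "ffn_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
        "feed_forward.w3": ColwiseParallel(),
    }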
@@ -211,33 +228,39 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
 One can see we now use ``PrepareModuleInput`` to modify the module input layouts of the Attention and FeedForward layers from ``Shard(1)`` to ``Replicate()``, and mark their output layouts as ``Shard(1)``.
 Just like what happens to Tensor Parallelism, one only needs to specify the tensor sharding layouts of the inputs and outputs, and the communication between layers will happen automatically.

 Note that with Sequence Parallel, we assume the inputs and outputs of a ``TransformerBlock`` are always sharded on the sequence dimension, so that multiple ``TransformerBlocks`` can be concatenated seamlessly.
-The only exception is that the input to the first ``TransformerBlock`` is replicated from the data loaders, so it has to be converted explicitly:
+This can be facilitated by explicitly specifying the output of the beginning ``nn.Embedding`` layer and the input of the final ``nn.Linear`` projection layer to be ``Shard(1)``:

 .. code-block:: python

     model = parallelize_module(
         model,
         tp_mesh,
-        "layers.0": PrepareModuleInput(
-            input_layouts=(Replicate(),),
-            desired_input_layouts=(Shard(1),),
-        ),
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate()
+            ),
+        }
     )

@@ -263,16 +286,16 @@ To apply Loss Parallel, the model predictions, usually of the shape ``[batch siz
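The body of this hunk is not shown above. For reference, applying Loss Parallel to predictions of shape ``[batch size, sequence length, vocabulary size]`` typically looks like the following sketch; ``pred`` and ``labels`` are placeholder names, and ``loss_parallel`` is the context manager from ``torch.distributed.tensor.parallel``.

.. code-block:: python

    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    # pred: DTensor of shape [batch, seq, vocab], sharded on the vocabulary (class) dimension
    # labels: replicated tensor of shape [batch, seq]
    with loss_parallel():
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
        # the backward pass must also run under the loss_parallel context
        loss.backward()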