
Commit 13e7981

[Tensor Parallel] update tutorial to simplify embedding + first transformer block (#2872)
1 parent 6424883 commit 13e7981

1 file changed

intermediate_source/TP_tutorial.rst

Lines changed: 35 additions & 12 deletions
@@ -164,6 +164,22 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
     )
 
 Now that we have elaborated the sharding plan for each ``TransformerBlock``, the model usually also has an ``nn.Embedding`` as its first layer and a final ``nn.Linear`` projection layer; the user can choose row-wise or column-wise sharding for the first ``nn.Embedding`` and column-wise sharding for the last ``nn.Linear`` projection layer, with the proper input and output layouts specified.
+Here is an example:
+
+.. code-block:: python
+
+    model = parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+            ),
+            "output": ColwiseParallel(
+                output_layouts=Replicate(),
+            ),
+        }
+    )
 
 .. note::
     If the model to be partitioned is too large to fit into CPU memory, one could either use ``meta`` device initialization (for example, initialize the model on the meta device first, shard the layers, and then materialize the model), or parallelize the ``TransformerBlock`` layer by layer during Transformer model initialization.
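
The ``meta`` device path described in the note above can be sketched roughly as follows. This is only an illustrative sketch: ``Transformer``/``ModelArgs`` stand in for a hypothetical model definition, ``tp_mesh`` is assumed to exist already, and the exact weight re-initialization step depends on the model (on older PyTorch releases ``Replicate`` lives in ``torch.distributed._tensor``):

.. code-block:: python

    import torch
    from torch.distributed.tensor import Replicate  # torch.distributed._tensor on older releases
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    from model import ModelArgs, Transformer  # hypothetical model definition

    # 1. Build the module structure on the meta device: no parameter memory is allocated.
    with torch.device("meta"):
        model = Transformer(ModelArgs())

    # 2. Apply the parallelization plan while parameters are still meta tensors
    #    (per-TransformerBlock plans would be applied the same way).
    model = parallelize_module(
        model,
        tp_mesh,  # assumed to have been created earlier via init_device_mesh
        {
            "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
            "output": ColwiseParallel(output_layouts=Replicate()),
        },
    )

    # 3. Materialize the (now sharded) parameters on the real device, then run the
    #    model's own weight initialization, e.g. a reset_parameters()-style routine.
    model = model.to_empty(device="cuda")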
@@ -203,6 +219,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
     layer_tp_plan = {
         # Now the input and output of SequenceParallel has Shard(1) layouts,
         # to represent the input/output tensors sharded on the sequence dimension
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
@@ -211,33 +228,39 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
 
 
 One can see we now use ``PrepareModuleInput`` to modify the module input layouts to the Attention and FeedForward layers from ``Shard(1)`` to ``Replicate()``, and mark their output layouts as ``Shard(1)``.
 Just like what happens to Tensor Parallelism, one only needs to specify the tensor sharding layouts of the inputs and outputs, and the communication between layers will happen automatically.
 
 Note that with Sequence Parallel, we assume the inputs and outputs of a ``TransformerBlock`` are always sharded on the sequence dimension, so that multiple ``TransformerBlocks`` can be concatenated seamlessly.
-The only exception is that the input to the first ``TransformerBlock`` is replicated from the data loaders, so it has to be converted explicitly:
+This can be facilitated by explicitly specifying the output of the beginning ``nn.Embedding`` layer and the input of the final ``nn.Linear`` projection layer to be ``Shard(1)``:
 
 .. code-block:: python
 
     model = parallelize_module(
         model,
         tp_mesh,
-        "layers.0": PrepareModuleInput(
-            input_layouts=(Replicate(),),
-            desired_input_layouts=(Shard(1),),
-        ),
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate()
+            ),
+        }
     )
 
 
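
For completeness, the ``tp_mesh`` referenced in the code above is typically a one-dimensional ``DeviceMesh``, and the per-block ``layer_tp_plan`` is applied to every ``TransformerBlock`` before the top-level call. A rough sketch under assumed names: the blocks are taken to live under an ``nn.ModuleList`` called ``model.layers``, and 8 GPUs are used purely as an example:

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module

    # 1-D device mesh dedicated to tensor parallelism (8 GPUs here as an example).
    tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

    # Apply the per-block plan to every TransformerBlock; the attribute name
    # `model.layers` is an assumption about the model definition.
    for transformer_block in model.layers:
        parallelize_module(transformer_block, tp_mesh, layer_tp_plan)

    # The embedding, final norm, and output projection are then handled by the
    # top-level parallelize_module call shown in the diff above.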
@@ -263,16 +286,16 @@ To apply Loss Parallel, the model predictions, usually of the shape ``[batch siz
         model,
         tp_mesh,
         {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
             "output": ColwiseParallel(
                 input_layouts=Shard(1),
                 # use DTensor as the output
                 use_local_output=False,
             ),
-            "norm": SequenceParallel(),
-            "layers.0": PrepareModuleInput(
-                input_layouts=(Replicate(),),
-                desired_input_layouts=(Shard(1),),
-            ),
         },
     )
 
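
Because ``use_local_output=False`` keeps the model output as a ``DTensor`` sharded on the vocabulary dimension, the cross-entropy loss is expected to be computed under the ``loss_parallel`` context manager. A minimal sketch of that training-step fragment, where ``pred`` and ``labels`` are assumed variable names:

.. code-block:: python

    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    # `pred` is the sharded DTensor returned by the parallelized model, with shape
    # [batch size, sequence length, vocabulary size]; `labels` holds token ids.
    with loss_parallel():
        # Flatten to [batch * seq_len, vocab] and [batch * seq_len] as usual.
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten())
        loss.backward()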
