Commit a4cd3e5

Update on "[Tensor Parallel] update tutorial to simplify embedding + first transformer block"
cross PR with pytorch/examples#1259 [ghstack-poisoned]
1 parent 94553e2 commit a4cd3e5

1 file changed: +2 −2 lines changed


intermediate_source/TP_tutorial.rst

Lines changed: 2 additions & 2 deletions
@@ -219,6 +219,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
     layer_tp_plan = {
         # Now the input and output of SequenceParallel has Shard(1) layouts,
         # to represent the input/output tensors sharded on the sequence dimension
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
@@ -227,15 +228,14 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
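
For context, here is a minimal sketch (not part of this commit) of how a ``layer_tp_plan`` like the one above is typically applied to each transformer block with ``parallelize_module``. The mesh size, the ``model.layers`` structure, the ``attention.wq`` entry, and the exact import paths are assumptions based on the tutorial and may differ in your PyTorch version.

# Minimal sketch, assuming a recent PyTorch and a Llama-style model whose
# decoder layers live in model.layers (model itself is not defined here).
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# 1-D device mesh over the GPUs used for tensor parallelism (8 is an assumption).
tp_mesh = init_device_mesh("cuda", (8,))

layer_tp_plan = {
    # SequenceParallel runs the norm on sequence-sharded activations (Shard(1)).
    "attention_norm": SequenceParallel(),
    # Gather the sequence-sharded input back to Replicate before attention.
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

# Apply the same per-layer plan to every transformer block in the model.
for transformer_block in model.layers:
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

Placing ``"attention_norm"`` and ``"ffn_norm"`` next to the modules they feed, as this commit does, keeps the plan readable without changing its behavior, since the dict keys are matched by module name rather than by order.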
