@@ -219,6 +219,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
 layer_tp_plan = {
     # Now the input and output of SequenceParallel has Shard(1) layouts,
     # to represent the input/output tensors sharded on the sequence dimension
+    "attention_norm": SequenceParallel(),
     "attention": PrepareModuleInput(
         input_layouts=(Shard(1),),
         desired_input_layouts=(Replicate(),),
@@ -227,15 +228,14 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
     "attention.wk": ColwiseParallel(),
     "attention.wv": ColwiseParallel(),
     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-    "attention_norm": SequenceParallel(),
+    "ffn_norm": SequenceParallel(),
     "feed_forward": PrepareModuleInput(
         input_layouts=(Shard(1),),
         desired_input_layouts=(Replicate(),),
     ),
     "feed_forward.w1": ColwiseParallel(),
     "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
     "feed_forward.w3": ColwiseParallel(),
-    "ffn_norm": SequenceParallel(),
 }
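
For context, below is a minimal sketch of how a ``layer_tp_plan`` like the one above is typically applied with ``parallelize_module`` from ``torch.distributed.tensor.parallel``. The 8-way mesh size, the ``model.layers`` attribute, and the loop are illustrative assumptions, not part of this commit.

    # Sketch only: assumes ``model`` is a Transformer whose blocks live in
    # ``model.layers`` and that ``layer_tp_plan`` is the dict defined above.
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module

    # Assumed 8-GPU tensor-parallel mesh; use the size that matches your setup.
    tp_mesh = init_device_mesh("cuda", (8,))

    # Apply the per-layer plan to every TransformerBlock in the model.
    for transformer_block in model.layers:
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_tp_plan,
        )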