
[Tensor Parallel] update tutorial to simplify embedding + first transformer block #2872


Merged
merged 1 commit into from
May 16, 2024
47 changes: 35 additions & 12 deletions intermediate_source/TP_tutorial.rst
@@ -164,6 +164,22 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
)

Now that we have elaborated the sharding plan for each ``TransformerBlock``, the Transformer model usually also has an ``nn.Embedding`` as its first layer and a final ``nn.Linear`` projection layer. The user can choose row-wise or column-wise sharding for the first ``nn.Embedding``, and column-wise sharding for the last ``nn.Linear`` projection layer, with the proper input and output layouts specified.
Here is an example:

.. code-block:: python

    model = parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
            ),
            "output": ColwiseParallel(
                output_layouts=Replicate(),
            ),
        }
    )

.. note::
   If the model to be partitioned is too large to fit into CPU memory, one could either use ``meta`` device initialization (for example, initialize the model on the meta device first, shard the layers, and then materialize the model), or parallelize the ``TransformerBlock`` layer by layer during the Transformer model initialization.
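
A minimal sketch of the meta-device path, assuming a ``Transformer``/``ModelArgs`` model definition with a ``layers`` container and an ``init_weights`` helper (these names are assumptions, not part of this PR), together with the ``tp_mesh`` and ``layer_tp_plan`` defined above:

.. code-block:: python

    import torch

    # allocate the model on the meta device: no real memory is used yet
    with torch.device("meta"):
        model = Transformer(ModelArgs())

    # shard each TransformerBlock while the parameters are still on the meta device
    for transformer_block in model.layers:
        parallelize_module(transformer_block, tp_mesh, layer_tp_plan)

    # materialize the (now sharded) parameters on GPU; ``to_empty`` leaves the
    # storage uninitialized, so re-run the model's own weight initialization
    model = model.to_empty(device="cuda")
    model.init_weights()  # assumed helper; call your own parameter init here
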
@@ -203,6 +219,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
    layer_tp_plan = {
        # Now the input and output of SequenceParallel have Shard(1) layouts,
        # to represent the input/output tensors sharded on the sequence dimension
        "attention_norm": SequenceParallel(),
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
@@ -211,33 +228,39 @@
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
"ffn_norm": SequenceParallel(),
    }


One can see that we now use ``PrepareModuleInput`` to convert the input layouts of the Attention and FeedForward layers from ``Shard(1)`` to ``Replicate()``, and mark their output layouts as ``Shard(1)``.
Just as with Tensor Parallelism, one only needs to specify the sharding layouts of the inputs and outputs, and the communication between layers will happen automatically.
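
To make the automatic communication concrete: converting an activation from ``Shard(1)`` to ``Replicate()`` is an all-gather along the sequence dimension, which is the communication ``PrepareModuleInput`` inserts before the Attention and FeedForward computations. A standalone DTensor sketch of that conversion (the activation shape is made up, and ``distribute_tensor`` is assumed to be importable from ``torch.distributed._tensor`` alongside ``Shard`` and ``Replicate``):

.. code-block:: python

    import torch
    from torch.distributed._tensor import distribute_tensor, Shard, Replicate

    # a [batch, seq, dim] activation sharded on the sequence dimension (dim 1)
    seq_sharded = distribute_tensor(torch.randn(8, 1024, 512), tp_mesh, [Shard(1)])

    # moving to Replicate() all-gathers the shards along the sequence dimension,
    # the same communication PrepareModuleInput issues for us automatically
    replicated = seq_sharded.redistribute(tp_mesh, [Replicate()])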

Note that with Sequence Parallel, we assume the inputs and outputs of a ``TransformerBlock`` are always sharded on the sequence dimension, so that multiple ``TransformerBlocks`` can be concatenated seamlessly.
This can be facilitated by explicitly specifying the output of the beginning ``nn.Embedding`` layer and the input of the final ``nn.Linear`` projection layer to be ``Shard(1)``:

.. code-block:: python

    model = parallelize_module(
        model,
        tp_mesh,
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(),),
desired_input_layouts=(Shard(1),),
),
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate()
            ),
        }
    )
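
With this plan in place, every rank can feed the same replicated token batch to the parallelized model. A small usage sketch (``vocab_size`` and the batch shape are made up):

.. code-block:: python

    import torch

    inp = torch.randint(0, vocab_size, (8, 1024), device="cuda")  # identical input on every rank
    out = model(inp)  # logits come back replicated because of ``output_layouts=Replicate()`` above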


Expand All @@ -263,16 +286,16 @@ To apply Loss Parallel, the model predictions, usually of the shape ``[batch siz
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                # use DTensor as the output
                use_local_output=False,
            ),
"norm": SequenceParallel(),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(),),
desired_input_layouts=(Shard(1),),
),
        },
    )
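
Because ``use_local_output=False`` keeps the model output as a DTensor sharded on the class dimension, the loss can then be computed inside the ``loss_parallel`` context manager. A short sketch, with ``pred`` and ``labels`` standing in for the model output and the targets:

.. code-block:: python

    import torch.nn.functional as F
    from torch.distributed.tensor.parallel import loss_parallel

    with loss_parallel():
        # cross_entropy here operates on the class-dim-sharded DTensor without
        # materializing the full vocab dimension on any single rank
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
        loss.backward()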
