grid based sharding for EBC #2445

Closed
wants to merge 1 commit

Conversation

@iamzainhuda (Contributor) commented Sep 30, 2024

Summary:

Introducing Grid based sharding for EBC

Grid sharding allows sharding a table along both the row and column dimensions. It is very useful for extremely large embedding tables and is best suited for scaling model sizes. Concretely, it is implemented as a combination of CW and TWRW: we first create the CW shards for a table and then shard each of those CW shards TWRW across a node.
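
As a rough illustration (not the actual planner code), the sketch below enumerates the shards such a grid layout produces: the embedding dimension is split into CW shards, and each CW shard is then row-wise sharded across the ranks of a node. `grid_shards` and its arguments are hypothetical names used only for this example.

```python
# Minimal sketch: enumerate grid shards for one table by first splitting
# columns (CW) and then splitting each column shard's rows across the ranks
# of a node (TWRW-style placement). Not the torchrec planner API.
from typing import List, Tuple

def grid_shards(
    num_rows: int,
    emb_dim: int,
    col_shard_dim: int,     # columns per CW shard
    local_world_size: int,  # ranks per node
) -> List[Tuple[int, int, int, int]]:
    """Returns (row_offset, col_offset, shard_rows, shard_cols) per grid shard."""
    shards = []
    rows_per_rank = (num_rows + local_world_size - 1) // local_world_size
    for col_offset in range(0, emb_dim, col_shard_dim):
        shard_cols = min(col_shard_dim, emb_dim - col_offset)
        # each CW shard is row-wise sharded across the ranks of one node
        for local_rank in range(local_world_size):
            row_offset = local_rank * rows_per_rank
            shard_rows = max(0, min(rows_per_rank, num_rows - row_offset))
            shards.append((row_offset, col_offset, shard_rows, shard_cols))
    return shards

# e.g. a 1M x 512 table, 128-dim CW shards, 8 ranks per node -> 4 * 8 grid shards
print(len(grid_shards(1_000_000, 512, 128, 8)))  # 32
```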

For extensive details, please view the design document; this diff summary showcases some key highlights.

Sharding Metadata Setup

Grid sharding fits seamlessly within the current ShardedEmbeddingTable config. The planner determines the sharding placements and is fixed to place the CW shards across a node in a TWRW manner (meaning some ranks can have an empty RW shard). We construct the metadata by borrowing from CW and TWRW: for metadata like embedding_dims, embedding_names, etc., we use per_node instead of per_rank. This change allows us to treat the table at the node level and reuse the existing comms components with slight tweaks to the data preparation and the args passed.
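
For illustration only, here is a minimal sketch of the per-rank to per-node regrouping idea, assuming a simple contiguous rank-to-node mapping; the helper name is hypothetical and is not a ShardedEmbeddingTable field.

```python
# Illustrative only: group per-rank metadata into per-node metadata, in the
# spirit of how the grid path handles fields like embedding_dims /
# embedding_names. Assumes ranks are assigned to nodes contiguously.
from collections import defaultdict

def per_rank_to_per_node(values_per_rank, local_world_size):
    per_node = defaultdict(list)
    for rank, value in enumerate(values_per_rank):
        per_node[rank // local_world_size].append(value)
    return [per_node[node] for node in sorted(per_node)]

# embedding dim of the shard owned by each of 4 ranks, 2 ranks per node
print(per_rank_to_per_node([128, 128, 128, 128], local_world_size=2))
# [[128, 128], [128, 128]]
```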

Comms Design

We combine optimizations from both CW and TWRW to keep the comms highly performant. Sparse feature distribution is done akin to the TWRW process: the input KJT is bucketized, permuted (according to the stagger indices), and sent via AlltoAll to the corresponding ranks that require their part of the input.
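
Below is a hedged sketch of just the row-wise bucketization step, written in plain PyTorch rather than the actual KJT kernels; the stagger permutation and the AlltoAll itself are omitted, and the function name is hypothetical.

```python
# Sketch: map global row ids to a destination bucket (rank within the node)
# and a local row id, so each rank only receives the ids that fall in its
# row range. Assumes equal-sized row blocks per bucket.
import torch

def bucketize_ids(ids: torch.Tensor, num_rows: int, num_buckets: int):
    block_size = (num_rows + num_buckets - 1) // num_buckets
    buckets = ids // block_size        # which rank's row range the id falls in
    local_ids = ids % block_size       # id relative to that rank's shard
    order = torch.argsort(buckets, stable=True)  # group ids by destination
    return buckets[order], local_ids[order]

ids = torch.tensor([3, 17, 42, 5, 60])
print(bucketize_ids(ids, num_rows=64, num_buckets=4))
# buckets: [0, 0, 1, 2, 3], local ids: [3, 5, 1, 10, 12]
```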

Similarly, for the embedding AlltoAll, we first all-reduce within the node to get the embedding lookup for the CW shard on that node, leveraging intra-node comms for this stage. Then, just like in TWRW, a reduce-scatter splits the embeddings so that each rank holds its own and its corresponding cross-group ranks' embeddings, which are then exchanged in an AlltoAll call. Lastly, the PermutePooledEmbeddingsSplit callback is called to rearrange the embedding lookup appropriately (concatenating the CW lookups in the right order).
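
A rough sketch of the intra-node reduction and cross-node AlltoAll steps using plain torch.distributed is shown below. It assumes the process groups `intra_node_pg` (ranks on one node) and `cross_node_pg` (same local rank across nodes) already exist (hypothetical names, not torchrec objects), assumes equal-sized splits, and collapses the intra-node reduction into a single reduce-scatter for brevity.

```python
# Rough sketch of the collective sequence, not the torchrec implementation.
import torch
import torch.distributed as dist

def grid_embedding_comms(local_lookup, intra_node_pg, cross_node_pg):
    # 1) reduce within the node so each rank ends up with its slice of the
    #    node's pooled CW-shard lookup (fast intra-node path)
    local_world = dist.get_world_size(group=intra_node_pg)
    slice_out = torch.empty(
        local_lookup.shape[0] // local_world, local_lookup.shape[1],
        device=local_lookup.device, dtype=local_lookup.dtype,
    )
    dist.reduce_scatter_tensor(slice_out, local_lookup, group=intra_node_pg)

    # 2) exchange slices with the corresponding ranks on the other nodes
    #    so every rank ends up with the CW pieces it needs
    gathered = torch.empty_like(slice_out)  # simplified: equal-sized splits
    dist.all_to_all_single(gathered, slice_out, group=cross_node_pg)

    # 3) a permute/cat callback (PermutePooledEmbeddingsSplit in this PR)
    #    would then reorder the CW pieces into the original embedding order
    return gathered
```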

Optimizer Sharding

Fused optimizer sharding is also updated: we needed to fix how row-wise optimizer states are constructed, since the optimizer states for CW shards are row-wise sharded. For grid sharding, the existing approach doesn't work, because the row-wise shards are repeated for each CW shard, and we can also encounter the uneven row-wise case, which is not possible in CW sharding. For grid shards, the approach is to use a rolling offset from the previous shard, which handles both uneven row-wise shards and the repeated CW shards.
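
As a minimal sketch of the rolling-offset idea (not the fused optimizer code itself): instead of deriving each shard's row offset from its position in the original table, which breaks when CW shards repeat the same row ranges or when the row splits are uneven, each shard's offset is accumulated from the previous shard's offset plus its row count, in shard order.

```python
# Hypothetical helper illustrating rolling row offsets for grid shards.
def rolling_row_offsets(shard_row_counts):
    offsets, running = [], 0
    for rows in shard_row_counts:
        offsets.append(running)   # offset = end of the previous shard
        running += rows
    return offsets

# e.g. two CW shards, each row-wise sharded unevenly across 2 ranks
print(rolling_row_offsets([600, 400, 600, 400]))  # [0, 600, 1000, 1600]
```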

NOTE: Bypasses were added in the planner to pass CI; they will be removed in a forthcoming diff.

Reviewed By: dstaay-fb

Differential Revision: D62594442

@facebook-github-bot added the CLA Signed label on Sep 30, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D62594442
