Summary:
Introducing Grid-based sharding for EBC

Grid sharding allows sharding a table along both the row and column dimensions. It is very useful for extremely large embedding tables and is best suited for scaling model sizes. Concretely, it is implemented as a combination of CW and TWRW: we first create the CW shards for a table and then TWRW-shard each of those shards across a node.

For extensive details please view the design document; this diff summary showcases some key highlights.
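As a rough illustration, the sketch below (not the planner's actual code) derives grid placements by first splitting the table column-wise and then row-wise sharding each CW shard across the ranks of its node. The node-major rank layout, the even row splits, and the helper name are assumptions made for illustration.

```python
# Minimal sketch of grid placement: CW split first, then each CW shard is
# row-wise sharded across the ranks of the node it lands on (TWRW).
from typing import List, Tuple

def grid_shard_placements(
    num_rows: int,
    emb_dim: int,
    cw_shard_dim: int,     # column width of each CW shard
    ranks_per_node: int,
) -> List[Tuple[int, int, int, int, int]]:
    """Returns (rank, row_offset, col_offset, rows, cols) per grid shard."""
    placements = []
    num_cw_shards = (emb_dim + cw_shard_dim - 1) // cw_shard_dim
    for cw_idx in range(num_cw_shards):
        col_offset = cw_idx * cw_shard_dim
        cols = min(cw_shard_dim, emb_dim - col_offset)
        # Each CW shard is assigned to one node and row-wise sharded across its ranks.
        rows_per_rank = (num_rows + ranks_per_node - 1) // ranks_per_node
        for local_rank in range(ranks_per_node):
            row_offset = local_rank * rows_per_rank
            rows = max(0, min(rows_per_rank, num_rows - row_offset))  # may be 0 (empty RW shard)
            rank = cw_idx * ranks_per_node + local_rank  # assumed node-major rank layout
            placements.append((rank, row_offset, col_offset, rows, cols))
    return placements

# Example: a 1,000,000 x 512 table, 128-wide CW shards, 8 ranks per node.
for placement in grid_shard_placements(1_000_000, 512, 128, 8)[:3]:
    print(placement)
```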
Sharding Metadata Setup
Grid sharding fits seamlessly within the current `ShardedEmbeddingTable` config. The planner determines the sharding placements and is fixed to place the CW shards across a node in a TWRW manner (meaning some ranks can have an empty RW shard). We construct the metadata by borrowing from CW and TWRW: for metadata like `embedding_dims`, `embedding_names`, etc., we use `per_node` instead of `per_rank`. This change allows us to treat the table at the level of its node and reuse the existing comms components with slight tweaks to the data preparation and the args passed.
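For illustration only (these are not the actual `ShardedEmbeddingTable` fields), a minimal sketch of grouping shard metadata per node rather than per rank: each CW shard contributes one embedding-dim/name entry to the node that hosts it, since the intra-node comms reconstruct the full CW shard at node granularity.

```python
# Hypothetical per-node grouping of shard metadata; names are illustrative.
from collections import defaultdict

# (table_name, cw_shard_dim, node) for the CW shards of two hypothetical tables.
cw_shards = [
    ("table_a", 128, 0),
    ("table_a", 128, 1),
    ("table_b", 64, 0),
]

embedding_dims_per_node = defaultdict(list)
embedding_names_per_node = defaultdict(list)
for name, dim, node in cw_shards:
    embedding_dims_per_node[node].append(dim)
    embedding_names_per_node[node].append(name)

print(dict(embedding_dims_per_node))   # {0: [128, 64], 1: [128]}
print(dict(embedding_names_per_node))  # {0: ['table_a', 'table_b'], 1: ['table_a']}
```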
Comms Design

We combine optimizations from both CW and TWRW to keep the comms highly performant. Sparse feature distribution is done akin to the TWRW process: the input KJT is bucketized, permuted (according to the stagger indices), and AlltoAll-ed to the corresponding ranks that require their part of the input.
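A simplified, single-process sketch of the bucketization step, written in plain PyTorch rather than TorchRec's input-dist code and reduced to node-level routing (the real path bucketizes per rank and applies the stagger permutation): ids are grouped by the node that owns their row range, producing the split sizes that the AlltoAll exchanges.

```python
import torch

def bucketize_ids_by_node(ids: torch.Tensor, rows_per_node: int, num_nodes: int):
    """Group ids by destination node and compute the AlltoAll split sizes."""
    dest_node = torch.clamp(ids // rows_per_node, max=num_nodes - 1)
    _, order = torch.sort(dest_node, stable=True)   # permute ids into node order
    bucketed_ids = ids[order]
    splits = torch.bincount(dest_node, minlength=num_nodes)  # lengths sent to each node
    return bucketed_ids, splits.tolist()

ids = torch.tensor([5, 901, 42, 650, 77])
bucketed, splits = bucketize_ids_by_node(ids, rows_per_node=500, num_nodes=2)
print(bucketed, splits)  # tensor([  5,  42,  77, 901, 650]) [3, 2]
```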
Similarly, for the embedding AlltoAll, we first all-reduce within the node to get the embedding lookup for the CW shard on that node; we're able to leverage the intra-node comms for this stage. Then comes the reduce-scatter, just like in TWRW, where the embeddings are split so that each rank holds its own and its corresponding cross-group ranks' embeddings, which are then shared in an AlltoAll call. Lastly, the `PermutePooledEmbeddingsSplit` callback is called to rearrange the embedding lookup appropriately (cat the CW lookups in the right order).
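A local, single-process sketch of the final step only: after the collectives, each rank holds pooled CW-shard lookups in arrival order, and they must be permuted and concatenated back into the table's original column order (the role `PermutePooledEmbeddingsSplit` plays in the real module). The tensors and offsets below are made up for illustration.

```python
import torch

batch_size = 2
# Pooled lookups per CW shard as received (node order), with their column offsets.
received = [
    torch.ones(batch_size, 64) * 2,   # CW shard originally at columns 128..191
    torch.ones(batch_size, 128),      # CW shard originally at columns 0..127
]
received_col_offsets = [128, 0]

# Permutation restoring ascending column offset, i.e. the original CW order.
perm = sorted(range(len(received)), key=lambda i: received_col_offsets[i])
output = torch.cat([received[i] for i in perm], dim=1)
print(output.shape)  # torch.Size([2, 192]) -- columns back in table order
```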
Optimizer Sharding

Fused optimizer sharding is also updated: we needed to fix how row-wise optimizer states are constructed, since the optimizers for CW shards are row-wise sharded. That approach doesn't work for grid sharding because the row-wise shards are repeated for each CW shard, and we can also encounter the uneven row-wise case, which is not possible in CW sharding. For grid shards, the approach is to use a rolling offset from the previous shard, which handles both uneven row-wise shards and the repeated CW shards.
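A hedged sketch of the rolling-offset idea (the helper is illustrative, not the fused-optimizer code): each shard's slice of the row-wise optimizer state starts where the previous shard's slice ends, which stays correct both when the same row ranges repeat across CW shards and when the row-wise splits are uneven.

```python
# Rolling offsets for row-wise optimizer state across grid shards.
from typing import List, Tuple

def rolling_row_offsets(shard_row_counts: List[int]) -> List[Tuple[int, int]]:
    """Returns (offset, rows) for each shard's slice of the optimizer state."""
    offsets = []
    offset = 0
    for rows in shard_row_counts:
        offsets.append((offset, rows))
        offset += rows  # next shard starts where this one ends
    return offsets

# Two CW shards, each row-wise split unevenly across a 3-rank node: the same row
# ranges repeat per CW shard, so naive rank-based offsets would collide.
shard_rows = [400, 400, 200,   # CW shard 0 on node 0
              400, 400, 200]   # CW shard 1 on node 1
print(rolling_row_offsets(shard_rows))
# [(0, 400), (400, 400), (800, 200), (1000, 400), (1400, 400), (1800, 200)]
```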
NOTE: Bypasses were added in the planner to pass CI; they are to be removed in a forthcoming diff.
Reviewed By: dstaay-fb
Differential Revision: D62594442