diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index d4b294b55..dc05d6027 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -204,6 +204,7 @@ def bucketize_kjt_before_all2all(
     kjt: KeyedJaggedTensor,
     num_buckets: int,
     block_sizes: torch.Tensor,
+    total_num_blocks: Optional[torch.Tensor] = None,
     output_permute: bool = False,
     bucketize_pos: bool = False,
     block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
@@ -219,6 +220,7 @@ def bucketize_kjt_before_all2all(
     Args:
         num_buckets (int): number of buckets to bucketize the values into.
         block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
+        total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature; useful for two-level bucketization.
         output_permute (bool): output the memory location mapping from the
             unbucketized values to bucketized values or not.
         bucketize_pos (bool): output the changed position of the bucketized values or
@@ -235,7 +237,7 @@ def bucketize_kjt_before_all2all(
         block_sizes.numel() == num_features,
         f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
     )
-    block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
+
     (
         bucketized_lengths,
         bucketized_indices,
@@ -247,14 +249,24 @@ def bucketize_kjt_before_all2all(
         kjt.values(),
         bucketize_pos=bucketize_pos,
         sequence=output_permute,
-        block_sizes=block_sizes_new_type,
+        block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()),
+        total_num_blocks=(
+            _fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values())
+            if total_num_blocks is not None
+            else None
+        ),
         my_size=num_buckets,
         weights=kjt.weights_or_none(),
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
-        block_bucketize_pos=block_bucketize_row_pos,  # each tensor should have the same dtype as kjt.lengths()
+        block_bucketize_pos=(
+            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            if block_bucketize_row_pos is not None
+            else None
+        ),
         keep_orig_idx=keep_original_indices,
     )
+
     return (
         KeyedJaggedTensor(
             # duplicate keys will be resolved by AllToAll
diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py
index e513b3e35..e397ea29b 100644
--- a/torchrec/distributed/mc_modules.py
+++ b/torchrec/distributed/mc_modules.py
@@ -389,19 +389,35 @@ def _create_input_dists(
         input_feature_names: List[str],
     ) -> None:
         for sharding, sharding_features in zip(
-            self._embedding_shardings, self._sharding_features
+            self._embedding_shardings,
+            self._sharding_features,
         ):
             assert isinstance(sharding, BaseRwEmbeddingSharding)
-            feature_hash_sizes: List[int] = [
+            feature_num_buckets: List[int] = [
+                self._managed_collision_modules[self._feature_to_table[f]].buckets()
+                for f in sharding_features
+            ]
+
+            input_sizes: List[int] = [
                 self._managed_collision_modules[self._feature_to_table[f]].input_size()
                 for f in sharding_features
             ]
 
+            feature_hash_sizes: List[int] = []
+            feature_total_num_buckets: List[int] = []
+            for input_size, num_buckets in zip(
+                input_sizes,
+                feature_num_buckets,
+            ):
+                feature_hash_sizes.append(input_size)
+                feature_total_num_buckets.append(num_buckets)
+
             input_dist = RwSparseFeaturesDist(
                 # pyre-ignore [6]
                 pg=sharding._pg,
                 num_features=sharding._get_num_features(),
                 feature_hash_sizes=feature_hash_sizes,
+                feature_total_num_buckets=feature_total_num_buckets,
                 device=sharding._device,
                 is_sequence=True,
                 has_feature_processor=sharding._has_feature_processor,
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index ccba69a78..ce60153c9 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -279,6 +279,7 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
             communication.
         num_features (int): total number of features.
         feature_hash_sizes (List[int]): hash sizes of features.
+        feature_total_num_buckets (Optional[List[int]]): total number of buckets per feature; if provided, must be a multiple of (and at least) the world size.
         device (Optional[torch.device]): device on which buffers will be allocated.
         is_sequence (bool): if this is for a sequence embedding.
         has_feature_processor (bool): existence of feature processor (ie. position
@@ -291,6 +292,7 @@ def __init__(
         pg: dist.ProcessGroup,
         num_features: int,
         feature_hash_sizes: List[int],
+        feature_total_num_buckets: Optional[List[int]] = None,
         device: Optional[torch.device] = None,
         is_sequence: bool = False,
         has_feature_processor: bool = False,
@@ -300,10 +302,16 @@ def __init__(
         super().__init__()
         self._world_size: int = pg.size()
         self._num_features = num_features
-        feature_block_sizes = [
-            (hash_size + self._world_size - 1) // self._world_size
-            for hash_size in feature_hash_sizes
-        ]
+
+        feature_block_sizes: List[int] = []
+
+        for i, hash_size in enumerate(feature_hash_sizes):
+            block_divisor = self._world_size
+            if feature_total_num_buckets is not None:
+                assert feature_total_num_buckets[i] % self._world_size == 0
+                block_divisor = feature_total_num_buckets[i]
+            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
+
         self.register_buffer(
             "_feature_block_sizes_tensor",
             torch.tensor(
@@ -311,7 +319,22 @@ def __init__(
                 device=device,
                 dtype=torch.int64,
             ),
+            persistent=False,
         )
+        self._has_multiple_blocks_per_shard: bool = (
+            feature_total_num_buckets is not None
+        )
+        if self._has_multiple_blocks_per_shard:
+            self.register_buffer(
+                "_feature_total_num_blocks_tensor",
+                torch.tensor(
+                    [feature_total_num_buckets],
+                    device=device,
+                    dtype=torch.int64,
+                ),
+                persistent=False,
+            )
+
         self._dist = KJTAllToAll(
             pg=pg,
             splits=[self._num_features] * self._world_size,
@@ -345,6 +368,11 @@ def forward(
             sparse_features,
             num_buckets=self._world_size,
             block_sizes=self._feature_block_sizes_tensor,
+            total_num_blocks=(
+                self._feature_total_num_blocks_tensor
+                if self._has_multiple_blocks_per_shard
+                else None
+            ),
             output_permute=self._is_sequence,
             bucketize_pos=(
                 self._has_feature_processor
diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py
index 135d7f47a..3e299192e 100644
--- a/torchrec/distributed/tests/test_utils.py
+++ b/torchrec/distributed/tests/test_utils.py
@@ -336,7 +336,9 @@ def test_kjt_bucketize_before_all2all(
 
         block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt, world_size, block_sizes, False, False
+            kjt=kjt,
+            num_buckets=world_size,
+            block_sizes=block_sizes,
         )
 
         expected_block_bucketized_kjt = block_bucketize_ref(
@@ -433,7 +435,10 @@ def test_kjt_bucketize_before_all2all_cpu(
         """
        block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
+            kjt=kjt,
+            num_buckets=world_size,
+            block_sizes=block_sizes,
+            block_bucketize_row_pos=block_bucketize_row_pos,
         )
 
         expected_block_bucketized_kjt = block_bucketize_ref(
diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py
index 2d20bc116..cc0302470 100644
--- a/torchrec/modules/mc_modules.py
+++ b/torchrec/modules/mc_modules.py
@@ -250,6 +250,13 @@ def input_size(self) -> int:
         """
         pass
 
+    @abc.abstractmethod
+    def buckets(self) -> int:
+        """
+        Returns the number of uniform buckets; relevant for resharding.
+        """
+        pass
+
     @abc.abstractmethod
     def validate_state(self) -> None:
         """
@@ -975,6 +982,7 @@ def __init__(
         name: Optional[str] = None,
         output_global_offset: int = 0,  # typically not provided by user
         output_segments: Optional[List[int]] = None,  # typically not provided by user
+        buckets: int = 1,
     ) -> None:
         if output_segments is None:
             output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1000,6 +1008,7 @@ def __init__(
         self._eviction_policy = eviction_policy
 
         self._current_iter: int = -1
+        self._buckets = buckets
         self._init_buffers()
 
         ## ------ history info ------
@@ -1302,6 +1311,9 @@ def forward(
 
     def output_size(self) -> int:
         return self._zch_size
 
+    def buckets(self) -> int:
+        return self._buckets
+
     def input_size(self) -> int:
         return self._input_hash_size
@@ -1349,4 +1361,5 @@ def rebuild_with_output_id_range(
             input_hash_func=self._input_hash_func,
             output_global_offset=output_id_range[0],
             output_segments=output_segments,
+            buckets=len(output_segments) - 1,
         )
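
Reviewer note: below is a minimal usage sketch (not part of the patch) showing how the new two-level bucketization arguments wire together. It assumes a recent fbgemm_gpu build whose block_bucketize_sparse_features kernel accepts total_num_blocks, and it assumes total_num_blocks follows the same one-entry-per-feature layout as block_sizes; world_size, buckets_per_feature, and hash_size are illustrative values.

# Sketch only; assumes fbgemm_gpu support for total_num_blocks.
import torch
from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

world_size = 2            # first-level buckets == number of ranks
buckets_per_feature = 4   # illustrative; must be divisible by world_size
hash_size = 100           # illustrative per-feature hash size

# With two-level bucketization the block size is ceil(hash_size / total buckets),
# not ceil(hash_size / world_size), mirroring the divisor change in
# RwSparseFeaturesDist.__init__ above.
block_sizes = torch.tensor(
    [(hash_size + buckets_per_feature - 1) // buckets_per_feature],
    dtype=torch.int64,
)
# Assumed layout: one entry per feature, analogous to block_sizes.
total_num_blocks = torch.tensor([buckets_per_feature], dtype=torch.int64)

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([3, 17, 42, 99], dtype=torch.int64),
    lengths=torch.tensor([2, 2], dtype=torch.int64),
)

bucketized_kjt, unbucketize_permute = bucketize_kjt_before_all2all(
    kjt=kjt,
    num_buckets=world_size,
    block_sizes=block_sizes,
    total_num_blocks=total_num_blocks,
    output_permute=True,
)
print(bucketized_kjt.keys(), bucketized_kjt.values(), unbucketize_permute)

In RwSparseFeaturesDist this wiring is automatic: passing feature_total_num_buckets shrinks each feature's block size from ceil(hash_size / world_size) to ceil(hash_size / total_num_buckets), while the number of first-level buckets stays equal to the world size.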