Commit 7ae70cd

dstaay-fb authored and facebook-github-bot committed
Re-shardable Hash Zch (#2556)
Summary:
Pull Request resolved: #2556
Pull Request resolved: #2538

Fully reshardable ZCH: with the default `total_num_buckets` value of 768, the training or inference world size can be any factor of `total_num_buckets` (i.e. 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, etc.) while keeping identical numerics.

Reviewed By: emlin

Differential Revision: D65912888

fbshipit-source-id: 55f68903ab0a7007e2dc7f2d22ce08613f5abb48
1 parent c57d124 commit 7ae70cd
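
The following is a minimal arithmetic sketch (not part of the commit; the hash size below is hypothetical) of why resharding keeps numerics identical: ids are mapped to buckets using the total bucket count rather than the world size, so the id-to-bucket step never changes, and any world size that divides `total_num_buckets` only regroups whole buckets across ranks.

    # Illustrative only; input_hash_size is a hypothetical per-feature id space.
    input_hash_size = 4_000_000
    total_num_buckets = 768  # the default value referenced above

    # Block size depends only on the total bucket count, never on the world size,
    # so the id -> bucket assignment is stable across reshardings.
    block_size = (input_hash_size + total_num_buckets - 1) // total_num_buckets

    def bucket_of(feature_id: int) -> int:
        return feature_id // block_size

    # Any factor of total_num_buckets gives every rank a whole number of buckets.
    for world_size in (1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128):
        assert total_num_buckets % world_size == 0
        buckets_per_rank = total_num_buckets // world_size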

File tree

5 files changed: +85 -11 lines changed

torchrec/distributed/embedding_sharding.py (+15 -3)
torchrec/distributed/mc_modules.py (+18 -2)
torchrec/distributed/sharding/rw_sharding.py (+32 -4)
torchrec/distributed/tests/test_utils.py (+7 -2)
torchrec/modules/mc_modules.py (+13 -0)

torchrec/distributed/embedding_sharding.py

+15 -3

@@ -204,6 +204,7 @@ def bucketize_kjt_before_all2all(
     kjt: KeyedJaggedTensor,
     num_buckets: int,
     block_sizes: torch.Tensor,
+    total_num_blocks: Optional[torch.Tensor] = None,
     output_permute: bool = False,
     bucketize_pos: bool = False,
     block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
@@ -219,6 +220,7 @@ def bucketize_kjt_before_all2all(
     Args:
         num_buckets (int): number of buckets to bucketize the values into.
         block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
+        total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization
         output_permute (bool): output the memory location mapping from the unbucketized
             values to bucketized values or not.
         bucketize_pos (bool): output the changed position of the bucketized values or
@@ -235,7 +237,7 @@
         block_sizes.numel() == num_features,
         f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
     )
-    block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
+
     (
         bucketized_lengths,
         bucketized_indices,
@@ -247,14 +249,24 @@
         kjt.values(),
         bucketize_pos=bucketize_pos,
         sequence=output_permute,
-        block_sizes=block_sizes_new_type,
+        block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()),
+        total_num_blocks=(
+            _fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values())
+            if total_num_blocks is not None
+            else None
+        ),
         my_size=num_buckets,
         weights=kjt.weights_or_none(),
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
-        block_bucketize_pos=block_bucketize_row_pos,  # each tensor should have the same dtype as kjt.lengths()
+        block_bucketize_pos=(
+            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            if block_bucketize_row_pos is not None
+            else None
+        ),
         keep_orig_idx=keep_original_indices,
     )
+
     return (
         KeyedJaggedTensor(
             # duplicate keys will be resolved by AllToAll
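
For orientation (not part of the diff): a hedged usage sketch of the extended bucketize_kjt_before_all2all API, mirroring the keyword-style calls in the updated tests. The feature name, sizes, and values below are made up, and running it requires torchrec with the fbgemm_gpu ops available.

    import torch
    from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # One hypothetical feature "f1" with four single-item lengths.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([1, 5, 9, 13], dtype=torch.int64),
        lengths=torch.tensor([1, 1, 1, 1], dtype=torch.int64),
    )

    world_size = 2
    total_num_buckets = 4  # must be a multiple of world_size
    hash_size = 16
    block_sizes = torch.tensor(
        [(hash_size + total_num_buckets - 1) // total_num_buckets], dtype=torch.int64
    )

    # total_num_blocks is the new per-feature tensor enabling two-level bucketization.
    bucketized_kjt, _ = bucketize_kjt_before_all2all(
        kjt=kjt,
        num_buckets=world_size,
        block_sizes=block_sizes,
        total_num_blocks=torch.tensor([total_num_buckets], dtype=torch.int64),
    )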

torchrec/distributed/mc_modules.py

+18 -2

@@ -389,19 +389,35 @@ def _create_input_dists(
         input_feature_names: List[str],
     ) -> None:
         for sharding, sharding_features in zip(
-            self._embedding_shardings, self._sharding_features
+            self._embedding_shardings,
+            self._sharding_features,
         ):
             assert isinstance(sharding, BaseRwEmbeddingSharding)
-            feature_hash_sizes: List[int] = [
+            feature_num_buckets: List[int] = [
+                self._managed_collision_modules[self._feature_to_table[f]].buckets()
+                for f in sharding_features
+            ]
+
+            input_sizes: List[int] = [
                 self._managed_collision_modules[self._feature_to_table[f]].input_size()
                 for f in sharding_features
             ]
 
+            feature_hash_sizes: List[int] = []
+            feature_total_num_buckets: List[int] = []
+            for input_size, num_buckets in zip(
+                input_sizes,
+                feature_num_buckets,
+            ):
+                feature_hash_sizes.append(input_size)
+                feature_total_num_buckets.append(num_buckets)
+
             input_dist = RwSparseFeaturesDist(
                 # pyre-ignore [6]
                 pg=sharding._pg,
                 num_features=sharding._get_num_features(),
                 feature_hash_sizes=feature_hash_sizes,
+                feature_total_num_buckets=feature_total_num_buckets,
                 device=sharding._device,
                 is_sequence=True,
                 has_feature_processor=sharding._has_feature_processor,
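
As a purely illustrative restatement of the list-building step above (the feature and table names are hypothetical), the two per-feature lists stay positionally aligned with sharding_features before being handed to RwSparseFeaturesDist:

    # Stand-ins for self._managed_collision_modules[table].input_size() / .buckets().
    sharding_features = ["f1", "f2"]
    feature_to_table = {"f1": "t1", "f2": "t2"}
    table_input_size = {"t1": 4_000_000, "t2": 2_000_000}
    table_buckets = {"t1": 768, "t2": 768}

    feature_hash_sizes = [table_input_size[feature_to_table[f]] for f in sharding_features]
    feature_total_num_buckets = [table_buckets[feature_to_table[f]] for f in sharding_features]
    # Both lists are forwarded to
    # RwSparseFeaturesDist(feature_hash_sizes=..., feature_total_num_buckets=...).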

torchrec/distributed/sharding/rw_sharding.py

+32 -4

@@ -287,6 +287,7 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
             communication.
         num_features (int): total number of features.
         feature_hash_sizes (List[int]): hash sizes of features.
+        feature_total_num_buckets (Optional[List[int]]): total number of buckets, if provided will be >= world size.
         device (Optional[torch.device]): device on which buffers will be allocated.
         is_sequence (bool): if this is for a sequence embedding.
         has_feature_processor (bool): existence of feature processor (ie. position
@@ -299,6 +300,7 @@ def __init__(
         pg: dist.ProcessGroup,
         num_features: int,
         feature_hash_sizes: List[int],
+        feature_total_num_buckets: Optional[List[int]] = None,
         device: Optional[torch.device] = None,
         is_sequence: bool = False,
         has_feature_processor: bool = False,
@@ -308,18 +310,39 @@ def __init__(
         super().__init__()
         self._world_size: int = pg.size()
         self._num_features = num_features
-        feature_block_sizes = [
-            (hash_size + self._world_size - 1) // self._world_size
-            for hash_size in feature_hash_sizes
-        ]
+
+        feature_block_sizes: List[int] = []
+
+        for i, hash_size in enumerate(feature_hash_sizes):
+            block_divisor = self._world_size
+            if feature_total_num_buckets is not None:
+                assert feature_total_num_buckets[i] % self._world_size == 0
+                block_divisor = feature_total_num_buckets[i]
+            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
+
         self.register_buffer(
             "_feature_block_sizes_tensor",
             torch.tensor(
                 feature_block_sizes,
                 device=device,
                 dtype=torch.int64,
             ),
+            persistent=False,
         )
+        self._has_multiple_blocks_per_shard: bool = (
+            feature_total_num_buckets is not None
+        )
+        if self._has_multiple_blocks_per_shard:
+            self.register_buffer(
+                "_feature_total_num_blocks_tensor",
+                torch.tensor(
+                    [feature_total_num_buckets],
+                    device=device,
+                    dtype=torch.int64,
+                ),
+                persistent=False,
+            )
+
         self._dist = KJTAllToAll(
             pg=pg,
             splits=[self._num_features] * self._world_size,
@@ -353,6 +376,11 @@ def forward(
             sparse_features,
             num_buckets=self._world_size,
             block_sizes=self._feature_block_sizes_tensor,
+            total_num_blocks=(
+                self._feature_total_num_blocks_tensor
+                if self._has_multiple_blocks_per_shard
+                else None
+            ),
             output_permute=self._is_sequence,
             bucketize_pos=(
                 self._has_feature_processor
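
Restating the new block-size computation on its own (sizes are hypothetical): when feature_total_num_buckets is supplied, the per-bucket block size is derived from the total bucket count, so it no longer depends on the world size.

    world_size = 8
    feature_hash_sizes = [4_000_000, 1_000_000]
    feature_total_num_buckets = [768, 768]  # each entry must be divisible by world_size

    feature_block_sizes = []
    for i, hash_size in enumerate(feature_hash_sizes):
        block_divisor = world_size
        if feature_total_num_buckets is not None:
            assert feature_total_num_buckets[i] % world_size == 0
            block_divisor = feature_total_num_buckets[i]
        feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)

    assert feature_block_sizes == [5209, 1303]  # ceil-division by 768, independent of world_size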

torchrec/distributed/tests/test_utils.py

+7 -2

@@ -336,7 +336,9 @@ def test_kjt_bucketize_before_all2all(
         block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()
 
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt, world_size, block_sizes, False, False
+            kjt=kjt,
+            num_buckets=world_size,
+            block_sizes=block_sizes,
         )
 
         expected_block_bucketized_kjt = block_bucketize_ref(
@@ -433,7 +435,10 @@ def test_kjt_bucketize_before_all2all_cpu(
         """
         block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
+            kjt=kjt,
+            num_buckets=world_size,
+            block_sizes=block_sizes,
+            block_bucketize_row_pos=block_bucketize_row_pos,
         )
 
         expected_block_bucketized_kjt = block_bucketize_ref(

torchrec/modules/mc_modules.py

+13 -0

@@ -250,6 +250,13 @@ def input_size(self) -> int:
         """
         pass
 
+    @abc.abstractmethod
+    def buckets(self) -> int:
+        """
+        Returns number of uniform buckets, relevant to resharding
+        """
+        pass
+
     @abc.abstractmethod
     def validate_state(self) -> None:
         """
@@ -975,6 +982,7 @@ def __init__(
         name: Optional[str] = None,
         output_global_offset: int = 0,  # typically not provided by user
         output_segments: Optional[List[int]] = None,  # typically not provided by user
+        buckets: int = 1,
     ) -> None:
         if output_segments is None:
             output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1000,6 +1008,7 @@ def __init__(
         self._eviction_policy = eviction_policy
 
         self._current_iter: int = -1
+        self._buckets = buckets
         self._init_buffers()
 
         ## ------ history info ------
@@ -1320,6 +1329,9 @@ def forward(
     def output_size(self) -> int:
         return self._zch_size
 
+    def buckets(self) -> int:
+        return self._buckets
+
     def input_size(self) -> int:
         return self._input_hash_size
 
@@ -1376,4 +1388,5 @@ def rebuild_with_output_id_range(
             input_hash_func=self._input_hash_func,
             output_global_offset=output_id_range[0],
             output_segments=output_segments,
+            buckets=len(output_segments) - 1,
         )
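
A short sketch of the bucket/segment relationship used by rebuild_with_output_id_range (numbers are hypothetical): output_segments holds the boundaries of contiguous output id ranges, so the bucket count is one less than the number of boundaries; with the constructor defaults (output_segments = [offset, offset + zch_size]) this reduces to a single bucket, matching buckets: int = 1.

    # Five boundaries delimit four contiguous output id ranges -> four buckets.
    output_segments = [0, 256, 512, 768, 1024]
    assert len(output_segments) - 1 == 4

    # Constructor default: one segment spanning the whole ZCH range -> one bucket.
    output_global_offset, zch_size = 0, 1024
    default_segments = [output_global_offset, output_global_offset + zch_size]
    assert len(default_segments) - 1 == 1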
