Commit b9cb159

dstaay-fb authored and facebook-github-bot committed
Re-shardable Hash Zch (#2538)
Summary: Fully reshardable ZCH: we can handle any world size that evenly divides the default bucket count (768), e.g. WS 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, etc., and scale up and down. Differential Revision: D65912888
1 parent be4b9c7 commit b9cb159
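
The invariant behind "fully reshardable" is that the number of uniform buckets stays fixed while the world size changes, so any world size that divides the bucket count gets a whole number of buckets per rank. A minimal sketch of that arithmetic, with an illustrative input hash size (not taken from the commit):

# Illustrative only: a fixed bucket count (768, the default cited in the summary)
# divides evenly across every listed world size, so shards can grow or shrink
# without re-bucketizing ids.
total_buckets = 768
input_hash_size = 2**24        # assumed per-feature input hash size
block_size = (input_hash_size + total_buckets - 1) // total_buckets  # rows per bucket

for world_size in (1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128):
    assert total_buckets % world_size == 0
    buckets_per_rank = total_buckets // world_size
    print(f"WS={world_size}: {buckets_per_rank} buckets/rank, "
          f"{buckets_per_rank * block_size} input ids/rank")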

File tree: 5 files changed, +85 -11 lines

torchrec/distributed/embedding_sharding.py (+15 -3)

@@ -204,6 +204,7 @@ def bucketize_kjt_before_all2all(
     kjt: KeyedJaggedTensor,
     num_buckets: int,
     block_sizes: torch.Tensor,
+    total_num_blocks: Optional[torch.Tensor] = None,
     output_permute: bool = False,
     bucketize_pos: bool = False,
     block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
@@ -219,6 +220,7 @@ def bucketize_kjt_before_all2all(
     Args:
         num_buckets (int): number of buckets to bucketize the values into.
         block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
+        total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization
         output_permute (bool): output the memory location mapping from the unbucketized
             values to bucketized values or not.
         bucketize_pos (bool): output the changed position of the bucketized values or
@@ -235,7 +237,7 @@
         block_sizes.numel() == num_features,
         f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
     )
-    block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
+
     (
         bucketized_lengths,
         bucketized_indices,
@@ -247,14 +249,24 @@
         kjt.values(),
         bucketize_pos=bucketize_pos,
         sequence=output_permute,
-        block_sizes=block_sizes_new_type,
+        block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()),
+        total_num_blocks=(
+            _fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values())
+            if total_num_blocks is not None
+            else None
+        ),
         my_size=num_buckets,
         weights=kjt.weights_or_none(),
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
-        block_bucketize_pos=block_bucketize_row_pos, # each tensor should have the same dtype as kjt.lengths()
+        block_bucketize_pos=(
+            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            if block_bucketize_row_pos is not None
+            else None
+        ),
         keep_orig_idx=keep_original_indices,
     )
+
     return (
         KeyedJaggedTensor(
             # duplicate keys will be resolved by AllToAll
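
A minimal sketch (not torchrec source) of the normalization this hunk applies: before the fbgemm kernel is called, the optional auxiliary tensors are passed through only if present and are cast to the device/dtype of the KJT's values or lengths. _fx_wrap_tensor_to_device_dtype is assumed to behave roughly like this for a single tensor.

from typing import Optional
import torch

def to_like(t: Optional[torch.Tensor], ref: torch.Tensor) -> Optional[torch.Tensor]:
    # Match the reference tensor's device and dtype; leave None untouched.
    return None if t is None else t.to(device=ref.device, dtype=ref.dtype)

values = torch.tensor([3, 4, 11], dtype=torch.int64)       # stands in for kjt.values()
block_sizes = torch.tensor([4], dtype=torch.int32)         # int32 -> cast to int64
total_num_blocks: Optional[torch.Tensor] = torch.tensor([8])

block_sizes = to_like(block_sizes, values)                 # always present
total_num_blocks = to_like(total_num_blocks, values)       # may be None
assert block_sizes is not None and block_sizes.dtype == values.dtype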

torchrec/distributed/mc_modules.py (+18 -2)

@@ -389,19 +389,35 @@ def _create_input_dists(
         input_feature_names: List[str],
     ) -> None:
         for sharding, sharding_features in zip(
-            self._embedding_shardings, self._sharding_features
+            self._embedding_shardings,
+            self._sharding_features,
         ):
             assert isinstance(sharding, BaseRwEmbeddingSharding)
-            feature_hash_sizes: List[int] = [
+            feature_num_buckets: List[int] = [
+                self._managed_collision_modules[self._feature_to_table[f]].buckets()
+                for f in sharding_features
+            ]
+
+            input_sizes: List[int] = [
                 self._managed_collision_modules[self._feature_to_table[f]].input_size()
                 for f in sharding_features
             ]

+            feature_hash_sizes: List[int] = []
+            feature_total_num_buckets: List[int] = []
+            for input_size, num_buckets in zip(
+                input_sizes,
+                feature_num_buckets,
+            ):
+                feature_hash_sizes.append(input_size)
+                feature_total_num_buckets.append(num_buckets)
+
             input_dist = RwSparseFeaturesDist(
                 # pyre-ignore [6]
                 pg=sharding._pg,
                 num_features=sharding._get_num_features(),
                 feature_hash_sizes=feature_hash_sizes,
+                feature_total_num_buckets=feature_total_num_buckets,
                 device=sharding._device,
                 is_sequence=True,
                 has_feature_processor=sharding._has_feature_processor,
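
The input-dist wiring above reads two quantities from each feature's managed collision module: input_size() (the raw hash space) and the new buckets() count. A hedged sketch of that pairing with a hypothetical stand-in module (the real modules are the MCH modules in torchrec.modules.mc_modules):

from typing import Dict, List

class StubMcModule:
    """Hypothetical stand-in exposing the two accessors the sharder now reads."""
    def __init__(self, input_size: int, buckets: int) -> None:
        self._input_size = input_size
        self._buckets = buckets
    def input_size(self) -> int:
        return self._input_size
    def buckets(self) -> int:
        return self._buckets

managed_collision_modules: Dict[str, StubMcModule] = {
    "t1": StubMcModule(input_size=2**31, buckets=768),
    "t2": StubMcModule(input_size=2**20, buckets=768),
}
feature_to_table = {"f1": "t1", "f2": "t2"}
sharding_features = ["f1", "f2"]

feature_hash_sizes: List[int] = []
feature_total_num_buckets: List[int] = []
for f in sharding_features:
    mc = managed_collision_modules[feature_to_table[f]]
    feature_hash_sizes.append(mc.input_size())
    feature_total_num_buckets.append(mc.buckets())
# Both lists are then handed to RwSparseFeaturesDist(feature_hash_sizes=...,
# feature_total_num_buckets=...), as in the diff above.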

torchrec/distributed/sharding/rw_sharding.py (+32 -4)

@@ -279,6 +279,7 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
             communication.
         num_features (int): total number of features.
         feature_hash_sizes (List[int]): hash sizes of features.
+        feature_total_num_buckets (Optional[List[int]]): total number of buckets, if provided will be >= world size.
         device (Optional[torch.device]): device on which buffers will be allocated.
         is_sequence (bool): if this is for a sequence embedding.
         has_feature_processor (bool): existence of feature processor (ie. position
@@ -291,6 +292,7 @@ def __init__(
         pg: dist.ProcessGroup,
         num_features: int,
         feature_hash_sizes: List[int],
+        feature_total_num_buckets: Optional[List[int]] = None,
         device: Optional[torch.device] = None,
         is_sequence: bool = False,
         has_feature_processor: bool = False,
@@ -300,18 +302,39 @@ def __init__(
         super().__init__()
         self._world_size: int = pg.size()
         self._num_features = num_features
-        feature_block_sizes = [
-            (hash_size + self._world_size - 1) // self._world_size
-            for hash_size in feature_hash_sizes
-        ]
+
+        feature_block_sizes: List[int] = []
+
+        for i, hash_size in enumerate(feature_hash_sizes):
+            block_divisor = self._world_size
+            if feature_total_num_buckets is not None:
+                assert feature_total_num_buckets[i] % self._world_size == 0
+                block_divisor = feature_total_num_buckets[i]
+            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
+
         self.register_buffer(
             "_feature_block_sizes_tensor",
             torch.tensor(
                 feature_block_sizes,
                 device=device,
                 dtype=torch.int64,
             ),
+            persistent=False,
         )
+        self._has_multiple_blocks_per_shard: bool = (
+            feature_total_num_buckets is not None
+        )
+        if self._has_multiple_blocks_per_shard:
+            self.register_buffer(
+                "_feature_total_num_blocks_tensor",
+                torch.tensor(
+                    [feature_total_num_buckets],
+                    device=device,
+                    dtype=torch.int64,
+                ),
+                persistent=False,
+            )
+
         self._dist = KJTAllToAll(
             pg=pg,
             splits=[self._num_features] * self._world_size,
@@ -345,6 +368,11 @@ def forward(
             sparse_features,
             num_buckets=self._world_size,
             block_sizes=self._feature_block_sizes_tensor,
+            total_num_blocks=(
+                self._feature_total_num_blocks_tensor
+                if self._has_multiple_blocks_per_shard
+                else None
+            ),
             output_permute=self._is_sequence,
             bucketize_pos=(
                 self._has_feature_processor
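
The constructor change above replaces the fixed per-rank block size with a per-bucket one whenever bucket counts are supplied; the bucket count must then be a multiple of the world size. A standalone sketch of the same rule (illustrative sizes):

from typing import List, Optional

def compute_feature_block_sizes(
    feature_hash_sizes: List[int],
    world_size: int,
    feature_total_num_buckets: Optional[List[int]] = None,
) -> List[int]:
    block_sizes: List[int] = []
    for i, hash_size in enumerate(feature_hash_sizes):
        divisor = world_size
        if feature_total_num_buckets is not None:
            assert feature_total_num_buckets[i] % world_size == 0
            divisor = feature_total_num_buckets[i]
        block_sizes.append((hash_size + divisor - 1) // divisor)  # ceil division
    return block_sizes

# Without buckets: one block per rank.
print(compute_feature_block_sizes([1_000_000], world_size=8))  # [125000]
# With 768 uniform buckets: 96 finer blocks land on each of the 8 ranks.
print(compute_feature_block_sizes([1_000_000], world_size=8,
                                  feature_total_num_buckets=[768]))  # [1303]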

torchrec/distributed/tests/test_utils.py (+7 -2)

@@ -336,7 +336,9 @@ def test_kjt_bucketize_before_all2all(
         block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()

         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt, world_size, block_sizes, False, False
+            kjt=kjt,
+            num_buckets=world_size,
+            block_sizes=block_sizes,
         )

         expected_block_bucketized_kjt = block_bucketize_ref(
@@ -433,7 +435,10 @@ def test_kjt_bucketize_before_all2all_cpu(
         """
         block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
+            kjt=kjt,
+            num_buckets=world_size,
+            block_sizes=block_sizes,
+            block_bucketize_row_pos=block_bucketize_row_pos,
         )

         expected_block_bucketized_kjt = block_bucketize_ref(

torchrec/modules/mc_modules.py (+13)

@@ -250,6 +250,13 @@ def input_size(self) -> int:
         """
         pass

+    @abc.abstractmethod
+    def buckets(self) -> int:
+        """
+        Returns number of uniform buckets, relevant to resharding
+        """
+        pass
+
     @abc.abstractmethod
     def validate_state(self) -> None:
         """
@@ -975,6 +982,7 @@ def __init__(
         name: Optional[str] = None,
         output_global_offset: int = 0, # typically not provided by user
         output_segments: Optional[List[int]] = None, # typically not provided by user
+        buckets: int = 1,
     ) -> None:
         if output_segments is None:
             output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1000,6 +1008,7 @@ def __init__(
         self._eviction_policy = eviction_policy

         self._current_iter: int = -1
+        self._buckets = buckets
         self._init_buffers()

         ## ------ history info ------
@@ -1302,6 +1311,9 @@ def forward(
     def output_size(self) -> int:
         return self._zch_size

+    def buckets(self) -> int:
+        return self._buckets
+
     def input_size(self) -> int:
         return self._input_hash_size

@@ -1349,4 +1361,5 @@ def rebuild_with_output_id_range(
             input_hash_func=self._input_hash_func,
             output_global_offset=output_id_range[0],
             output_segments=output_segments,
+            buckets=len(output_segments) - 1,
         )
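
The new buckets() accessor and the buckets constructor argument close the loop on the module side: a module built with N uniform buckets reports N, and a shard rebuilt from output_segments infers its bucket count from the segment boundaries it receives. A small sketch of that arithmetic with assumed sizes (768 buckets, 96 owned by one shard):

# Illustrative only: segment boundaries are assumed to include both endpoints
# (the default is [offset, offset + zch_size]), so a shard that receives
# output_segments with k+1 boundaries reports k buckets, matching
# buckets=len(output_segments) - 1 in rebuild_with_output_id_range above.
zch_size = 768 * 1000          # assumed total ZCH output range
total_buckets = 768            # module constructed with buckets=768
segment_size = zch_size // total_buckets     # 1000 rows per uniform bucket

owned_buckets = 96             # e.g. one of 8 ranks owning 768 / 8 buckets
output_segments = [i * segment_size for i in range(owned_buckets + 1)]
assert len(output_segments) - 1 == owned_buckets

# The default (no explicit buckets) degenerates to a single segment:
default_segments = [0, zch_size]
assert len(default_segments) - 1 == 1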
