
Commit 8ec0540

aporialiao authored and facebook-github-bot committed
2/n Replace Input Generation + Documentation (#2875)
Summary:
Pull Request resolved: #2875

* Use the refactored `ModelInput` generation from D71556886 for the dynamic sharding unit test. This will be important for CW sharding, since the existing input generation call here is insufficient for handling multiple shards per table.
* Add additional docs explaining uses for some of the Dynamic Sharding API.
* Rename the ModelInput KJT creation method to indicate static use outside of the ModelInput class.

Reviewed By: TroyGarden

Differential Revision: D72798357

fbshipit-source-id: bd6b57a12f212d731779a9cce8206a2347b966c3
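For context, the sketch below (not part of the commit) illustrates the per-rank input generation pattern the test moves to: the renamed, now externally usable `ModelInput.create_standard_kjt` builds one KeyedJaggedTensor per rank from a shared list of table configs. The table names, sizes, and batch size are illustrative placeholders; it assumes torchrec is installed.

```python
# Illustrative sketch: build one randomly generated KJT per rank with the
# renamed ModelInput.create_standard_kjt (table configs below are made up).
from torchrec.distributed.test_utils.test_input import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
        num_embeddings=4,
        embedding_dim=16,
    )
    for i in range(3)
]

world_size = 2
# One KJT per rank, all generated from the same table configs.
kjt_input_per_rank = [
    ModelInput.create_standard_kjt(batch_size=2, tables=tables)
    for _ in range(world_size)
]
print(kjt_input_per_rank[0].keys())  # feature keys come from the table configs
```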
1 parent d686bb1 · commit 8ec0540

File tree

3 files changed: +43 -50 lines


torchrec/distributed/embeddingbag.py

+6
@@ -855,6 +855,12 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
+
+        Args:
+            skip_registering (bool): If True, skips registering state_dict hooks. This is useful
+                for dynamic sharding where the state_dict hooks do not need to be
+                reregistered when being resharded. Default is False.
+
         """
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
         for table_name in self._table_names:
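As a reading aid for the new docstring, here is a minimal, hypothetical sketch (invented class and method names, not TorchRec code) of the pattern `skip_registering` describes: state_dict hooks are registered once on the first initialization and left in place, rather than re-registered, when the module is re-initialized as part of resharding. It uses PyTorch's public `register_state_dict_pre_hook`, which is not necessarily the exact hook TorchRec registers.

```python
import torch
import torch.nn as nn


class ToyShardedModule(nn.Module):
    """Toy stand-in for a sharded module; only demonstrates hook registration."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 8))

    def _initialize_torch_state(self, skip_registering: bool = False) -> None:
        # First initialization registers the hook; the resharding path passes
        # skip_registering=True because the hook is already registered.
        if not skip_registering:
            self.register_state_dict_pre_hook(self._pre_state_dict_hook)

    @staticmethod
    def _pre_state_dict_hook(module: nn.Module, prefix: str, keep_vars: bool) -> None:
        # Placeholder for work that would run before state_dict() is materialized.
        pass


m = ToyShardedModule()
m._initialize_torch_state()                       # initial init: register hooks
m._initialize_torch_state(skip_registering=True)  # reshard: keep existing hooks
```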

torchrec/distributed/test_utils/test_input.py

+27 -5

@@ -53,6 +53,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
         )

     def record_stream(self, stream: torch.Stream) -> None:
+        """
+        need to explicitly call `record_stream` for non-pytorch native object (KJT)
+        """
         self.float_features.record_stream(stream)
         if isinstance(self.idlist_features, KeyedJaggedTensor):
             self.idlist_features.record_stream(stream)
@@ -204,7 +207,7 @@ def generate_local_batches(
         all_zeros: bool = False,
     ) -> List["ModelInput"]:
         """
-        Returns multi-rank batches of world_size
+        Returns multi-rank batches (ModelInput) of world_size
         """
         return [
             cls.generate(
@@ -255,15 +258,15 @@ def generate(
         all_zeros: bool = False,
     ) -> "ModelInput":
         """
-        Returns a single batch
+        Returns a single batch of `ModelInput`
         """
         float_features = (
             torch.zeros((batch_size, num_float_features), device=device)
             if all_zeros
             else torch.rand((batch_size, num_float_features), device=device)
         )
         idlist_features = (
-            ModelInput._create_standard_kjt(
+            ModelInput.create_standard_kjt(
                 batch_size=batch_size,
                 tables=tables,
                 pooling_avg=pooling_avg,
@@ -281,7 +284,7 @@ def generate(
             else None
         )
         idscore_features = (
-            ModelInput._create_standard_kjt(
+            ModelInput.create_standard_kjt(
                 batch_size=batch_size,
                 tables=weighted_tables,
                 pooling_avg=pooling_avg,
@@ -324,6 +327,13 @@ def _create_features_lengths_indices(
         lengths_dtype: torch.dtype = torch.int64,
         all_zeros: bool = False,
     ) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
+        """
+        Create keys, lengths, and indices for a KeyedJaggedTensor from embedding table configs.
+
+        Returns:
+            Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
+                Feature names, per-feature lengths, and per-feature indices.
+        """
         pooling_factor_per_feature: List[int] = []
         num_embeddings_per_feature: List[int] = []
         max_length_per_feature: List[Optional[int]] = []
@@ -395,6 +405,14 @@ def _assemble_kjt(
         use_offsets: bool = False,
         offsets_dtype: torch.dtype = torch.int64,
     ) -> KeyedJaggedTensor:
+        """
+
+        Assembles a KeyedJaggedTensor (KJT) from the provided per-feature lengths and indices.
+
+        This method is used to generate corresponding local_batches and global_batch KJTs.
+        It concatenates the lengths and indices for each feature to form a complete KJT.
+        """
+
         lengths = torch.cat(lengths_per_feature)
         indices = torch.cat(indices_per_feature)
         offsets = None
@@ -407,7 +425,7 @@
         return KeyedJaggedTensor(features, indices, weights, lengths, offsets)

     @staticmethod
-    def _create_standard_kjt(
+    def create_standard_kjt(
         batch_size: int,
         tables: Union[
             List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig]
@@ -464,6 +482,10 @@ def _create_batched_standard_kjts(
         lengths_dtype: torch.dtype = torch.int64,
         all_zeros: bool = False,
     ) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]:
+        """
+        generate a global KJT and corresponding per-rank KJTs, the data are the same
+        so that they can be used for result comparison.
+        """
         data_per_rank = [
             ModelInput._create_features_lengths_indices(
                 batch_size,
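To make the new `_assemble_kjt` docstring concrete, the snippet below reproduces the described pattern with plain TorchRec APIs: per-feature lengths and indices are concatenated and wrapped into a single KeyedJaggedTensor. The feature names and values are toy data, not what the test utilities generate.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

keys = ["feature_0", "feature_1"]
# batch_size = 3: each lengths tensor gives the number of ids per sample.
lengths_per_feature = [torch.tensor([1, 2, 0]), torch.tensor([2, 0, 1])]
indices_per_feature = [torch.tensor([3, 0, 1]), torch.tensor([2, 2, 1])]

# Concatenate per-feature data and assemble one KJT covering both features.
kjt = KeyedJaggedTensor(
    keys=keys,
    values=torch.cat(indices_per_feature),
    lengths=torch.cat(lengths_per_feature),
)
print(kjt["feature_1"].values())  # ids belonging to feature_1 only
```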

torchrec/distributed/tests/test_dynamic_sharding.py

+10 -45

@@ -35,6 +35,7 @@
     MultiProcessContext,
     MultiProcessTestBase,
 )
+from torchrec.distributed.test_utils.test_input import ModelInput
 from torchrec.distributed.test_utils.test_sharding import copy_state_dict

 from torchrec.distributed.types import (
@@ -58,46 +59,6 @@ def feature_name(i: int) -> str:
     return "feature_" + str(i)


-def generate_input_by_world_size(
-    world_size: int,
-    num_tables: int,
-    num_embeddings: int = 4,
-    max_mul: int = 3,
-) -> List[KeyedJaggedTensor]:
-    # TODO merge with new ModelInput generator in TestUtils
-    kjt_input_per_rank = []
-    mul = random.randint(1, max_mul)
-    total_size = num_tables * mul
-
-    for _ in range(world_size):
-        feature_names = [feature_name(i) for i in range(num_tables)]
-        lengths = []
-        values = []
-        counting_l = 0
-        for i in range(total_size):
-            if i == total_size - 1:
-                lengths.append(total_size - counting_l)
-                break
-            next_l = random.randint(0, total_size - counting_l)
-            values.extend(
-                [random.randint(0, num_embeddings - 1) for _ in range(next_l)]
-            )
-            lengths.append(next_l)
-            counting_l += next_l
-
-        # for length in lengths:
-
-        kjt_input_per_rank.append(
-            KeyedJaggedTensor.from_lengths_sync(
-                keys=feature_names,
-                values=torch.LongTensor(values),
-                lengths=torch.LongTensor(lengths),
-            )
-        )
-
-    return kjt_input_per_rank
-
-
 def generate_embedding_bag_config(
     data_type: DataType,
     num_tables: int = 3,
@@ -372,9 +333,13 @@ def _run_ebc_resharding_test(
        ):
            return

-        kjt_input_per_rank = generate_input_by_world_size(
-            world_size, num_tables, num_embeddings
-        )
+        kjt_input_per_rank = [
+            ModelInput.create_standard_kjt(
+                batch_size=2,
+                tables=embedding_bag_config,
+            )
+            for _ in range(world_size)
+        ]

         # initial_state_dict filled with deterministic dummy values
         initial_state_dict = create_test_initial_state_dict(
@@ -418,8 +383,8 @@ def test_dynamic_sharding_ebc_tw(
         old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
         new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]

-        if new_ranks == old_ranks:
-            return
+        while new_ranks == old_ranks:
+            new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
         per_param_sharding = {}
         new_per_param_sharding = {}
