@@ -53,6 +53,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
53
53
)
54
54
55
55
def record_stream (self , stream : torch .Stream ) -> None :
56
+ """
57
+ Need to explicitly call `record_stream` on non-PyTorch-native objects (KJT).
58
+ """
56
59
self .float_features .record_stream (stream )
57
60
if isinstance (self .idlist_features , KeyedJaggedTensor ):
58
61
self .idlist_features .record_stream (stream )
@@ -204,7 +207,7 @@ def generate_local_batches(
204
207
all_zeros : bool = False ,
205
208
) -> List ["ModelInput" ]:
206
209
"""
207
- Returns multi-rank batches of world_size
210
+ Returns a list of multi-rank `ModelInput` batches of size world_size
208
211
"""
209
212
return [
210
213
cls .generate (
@@ -255,15 +258,15 @@ def generate(
255
258
all_zeros : bool = False ,
256
259
) -> "ModelInput" :
257
260
"""
258
- Returns a single batch
261
+ Returns a single batch of `ModelInput`
259
262
"""
260
263
float_features = (
261
264
torch .zeros ((batch_size , num_float_features ), device = device )
262
265
if all_zeros
263
266
else torch .rand ((batch_size , num_float_features ), device = device )
264
267
)
265
268
idlist_features = (
266
- ModelInput ._create_standard_kjt (
269
+ ModelInput .create_standard_kjt (
267
270
batch_size = batch_size ,
268
271
tables = tables ,
269
272
pooling_avg = pooling_avg ,
@@ -281,7 +284,7 @@ def generate(
281
284
else None
282
285
)
283
286
idscore_features = (
284
- ModelInput ._create_standard_kjt (
287
+ ModelInput .create_standard_kjt (
285
288
batch_size = batch_size ,
286
289
tables = weighted_tables ,
287
290
pooling_avg = pooling_avg ,
@@ -324,6 +327,13 @@ def _create_features_lengths_indices(
324
327
lengths_dtype : torch .dtype = torch .int64 ,
325
328
all_zeros : bool = False ,
326
329
) -> Tuple [List [str ], List [torch .Tensor ], List [torch .Tensor ]]:
330
+ """
331
+ Create keys, lengths, and indices for a KeyedJaggedTensor from embedding table configs.
332
+
333
+ Returns:
334
+ Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
335
+ Feature names, per-feature lengths, and per-feature indices.
336
+ """
327
337
pooling_factor_per_feature : List [int ] = []
328
338
num_embeddings_per_feature : List [int ] = []
329
339
max_length_per_feature : List [Optional [int ]] = []
@@ -395,6 +405,14 @@ def _assemble_kjt(
395
405
use_offsets : bool = False ,
396
406
offsets_dtype : torch .dtype = torch .int64 ,
397
407
) -> KeyedJaggedTensor :
408
+ """
409
+
410
+ Assembles a KeyedJaggedTensor (KJT) from the provided per-feature lengths and indices.
411
+
412
+ This method is used to generate corresponding local_batches and global_batch KJTs.
413
+ It concatenates the lengths and indices for each feature to form a complete KJT.
414
+ """
415
+
398
416
lengths = torch .cat (lengths_per_feature )
399
417
indices = torch .cat (indices_per_feature )
400
418
offsets = None
@@ -407,7 +425,7 @@ def _assemble_kjt(
407
425
return KeyedJaggedTensor (features , indices , weights , lengths , offsets )
408
426
409
427
@staticmethod
410
- def _create_standard_kjt (
428
+ def create_standard_kjt (
411
429
batch_size : int ,
412
430
tables : Union [
413
431
List [EmbeddingTableConfig ], List [EmbeddingBagConfig ], List [EmbeddingConfig ]
@@ -464,6 +482,10 @@ def _create_batched_standard_kjts(
464
482
lengths_dtype : torch .dtype = torch .int64 ,
465
483
all_zeros : bool = False ,
466
484
) -> Tuple [KeyedJaggedTensor , List [KeyedJaggedTensor ]]:
485
+ """
486
+ Generate a global KJT and the corresponding per-rank KJTs. The data are the same
487
+ so that they can be used for result comparison.
488
+ """
467
489
data_per_rank = [
468
490
ModelInput ._create_features_lengths_indices (
469
491
batch_size ,
0 commit comments