@@ -279,6 +279,7 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
             communication.
         num_features (int): total number of features.
         feature_hash_sizes (List[int]): hash sizes of features.
+        feature_total_num_buckets (Optional[List[int]]): total number of buckets per feature; if provided, each entry must be a multiple of the world size.
         device (Optional[torch.device]): device on which buffers will be allocated.
         is_sequence (bool): if this is for a sequence embedding.
         has_feature_processor (bool): existence of feature processor (ie. position
@@ -291,6 +292,7 @@ def __init__(
         pg: dist.ProcessGroup,
         num_features: int,
         feature_hash_sizes: List[int],
+        feature_total_num_buckets: Optional[List[int]] = None,
         device: Optional[torch.device] = None,
         is_sequence: bool = False,
         has_feature_processor: bool = False,
@@ -300,18 +302,39 @@ def __init__(
         super().__init__()
         self._world_size: int = pg.size()
         self._num_features = num_features
-        feature_block_sizes = [
-            (hash_size + self._world_size - 1) // self._world_size
-            for hash_size in feature_hash_sizes
-        ]
+
+        feature_block_sizes: List[int] = []
+
+        for i, hash_size in enumerate(feature_hash_sizes):
+            block_divisor = self._world_size
+            if feature_total_num_buckets is not None:
+                assert feature_total_num_buckets[i] % self._world_size == 0
+                block_divisor = feature_total_num_buckets[i]
+            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
+
         self.register_buffer(
             "_feature_block_sizes_tensor",
             torch.tensor(
                 feature_block_sizes,
                 device=device,
                 dtype=torch.int64,
             ),
+            persistent=False,
         )
+        self._has_multiple_blocks_per_shard: bool = (
+            feature_total_num_buckets is not None
+        )
+        if self._has_multiple_blocks_per_shard:
+            self.register_buffer(
+                "_feature_total_num_blocks_tensor",
+                torch.tensor(
+                    [feature_total_num_buckets],
+                    device=device,
+                    dtype=torch.int64,
+                ),
+                persistent=False,
+            )
+
         self._dist = KJTAllToAll(
             pg=pg,
             splits=[self._num_features] * self._world_size,
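
Note: the block-size computation above amounts to ceiling division of each feature's hash size by either the world size or the per-feature bucket count. A minimal standalone sketch of that logic (not part of this diff; the helper name and the example numbers are made up):

# Minimal sketch, mirroring the block-size logic in __init__ above.
from typing import List, Optional

def compute_feature_block_sizes(
    feature_hash_sizes: List[int],
    world_size: int,
    feature_total_num_buckets: Optional[List[int]] = None,
) -> List[int]:
    block_sizes: List[int] = []
    for i, hash_size in enumerate(feature_hash_sizes):
        divisor = world_size
        if feature_total_num_buckets is not None:
            # Each feature's bucket count must be a multiple of the world size.
            assert feature_total_num_buckets[i] % world_size == 0
            divisor = feature_total_num_buckets[i]
        # Ceiling division: rows per block.
        block_sizes.append((hash_size + divisor - 1) // divisor)
    return block_sizes

# world_size=2, one feature with 10 rows:
#   default:                        ceil(10 / 2) = 5 rows per block (one block per rank)
#   feature_total_num_buckets=[4]:  ceil(10 / 4) = 3 rows per block (two blocks per rank)
print(compute_feature_block_sizes([10], 2))       # [5]
print(compute_feature_block_sizes([10], 2, [4]))  # [3]
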
@@ -345,6 +368,11 @@ def forward(
             sparse_features,
             num_buckets=self._world_size,
             block_sizes=self._feature_block_sizes_tensor,
+            total_num_blocks=(
+                self._feature_total_num_blocks_tensor
+                if self._has_multiple_blocks_per_shard
+                else None
+            ),
             output_permute=self._is_sequence,
             bucketize_pos=(
                 self._has_feature_processor
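
Note: passing total_num_blocks through to bucketization is what allows a rank to own more than one block per feature. A rough sketch of the id-to-block arithmetic implied by the block sizes computed in __init__ (how blocks are then assigned to ranks happens inside the bucketization kernel and is not shown in this diff; the numbers are made up):

# Hedged illustration: block index per row id for one feature with
# hash_size=10 and world_size=2.
ids = list(range(10))

# Without feature_total_num_buckets: block_size = 5, world_size blocks in total.
print([i // 5 for i in ids])  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

# With feature_total_num_buckets=[4]: block_size = 3, 4 blocks in total,
# so each of the 2 ranks ends up owning 2 of the 4 blocks.
print([i // 3 for i in ids])  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3]
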