Skip to content

Commit d686bb1

Browse files
basilwong authored and facebook-github-bot committed
Remove long_indices argument from model.generate function (#2874)
Summary: Pull Request resolved: #2874 Related to stack starting with: D66728661 Removing long_indices as an input for more precise input values: ``` use_offsets: bool = False, indices_dtype: torch.dtype = torch.int64, offsets_dtype: torch.dtype = torch.int64, lengths_dtype: torch.dtype = torch.int64, ``` Reviewed By: TroyGarden Differential Revision: D69883857 fbshipit-source-id: aeb6b2590af4ddd9bef85fa629b6407bfbb6567b
1 parent f68bc72 commit d686bb1

File tree

6 files changed

+14
-17
lines changed

6 files changed

+14
-17
lines changed

torchrec/distributed/benchmark/benchmark_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,9 @@ def get_inputs(
372372
num_float_features=0,
373373
tables=tables,
374374
weighted_tables=[],
375-
long_indices=False,
376375
tables_pooling=pooling_configs,
376+
indices_dtype=torch.int32,
377+
lengths_dtype=torch.int32,
377378
)
378379

379380
if train:

torchrec/distributed/test_utils/infer_utils.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,12 @@ def prep_inputs(
187187
long_indices: bool = True,
188188
) -> List[ModelInput]:
189189
inputs = []
190+
if long_indices:
191+
indices_dtype = torch.int64
192+
lengths_dtype = torch.int64
193+
else:
194+
indices_dtype = torch.int32
195+
lengths_dtype = torch.int32
190196
for _ in range(count):
191197
inputs.append(
192198
ModelInput.generate(
@@ -195,7 +201,8 @@ def prep_inputs(
195201
num_float_features=model_info.num_float_features,
196202
tables=model_info.tables,
197203
weighted_tables=model_info.weighted_tables,
198-
long_indices=long_indices,
204+
indices_dtype=indices_dtype,
205+
lengths_dtype=lengths_dtype,
199206
)[1][0],
200207
)
201208

torchrec/distributed/test_utils/test_model.py

-9
Original file line numberDiff line numberDiff line change
@@ -98,20 +98,11 @@ def generate(
9898
indices_dtype: torch.dtype = torch.int64,
9999
offsets_dtype: torch.dtype = torch.int64,
100100
lengths_dtype: torch.dtype = torch.int64,
101-
long_indices: bool = True, # TODO - remove this once code base is updated to support more than long_indices spec
102101
) -> Tuple["ModelInput", List["ModelInput"]]:
103102
"""
104103
Returns a global (single-rank training) batch
105104
and a list of local (multi-rank training) batches of world_size.
106105
"""
107-
if long_indices:
108-
indices_dtype = torch.int64
109-
lengths_dtype = torch.int64
110-
use_offsets = False
111-
else:
112-
indices_dtype = torch.int32
113-
lengths_dtype = torch.int32
114-
use_offsets = False
115106
batch_size_by_rank = [batch_size] * world_size
116107
if variable_batch_size:
117108
batch_size_by_rank = [

torchrec/distributed/test_utils/test_model_parallel_base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def _test_sharded_forward(
113113
dense_device=cuda_device,
114114
sparse_device=cuda_device,
115115
generate=generate,
116-
long_indices=False,
116+
indices_dtype=torch.int32,
117+
lengths_dtype=torch.int32,
117118
)
118119
global_model = quantize_callable(global_model, **quantize_callable_kwargs)
119120
local_input = _inputs[0][1][default_rank].to(cuda_device)

torchrec/distributed/test_utils/test_sharding.py

-4
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def __call__(
120120
indices_dtype: torch.dtype = torch.int64,
121121
offsets_dtype: torch.dtype = torch.int64,
122122
lengths_dtype: torch.dtype = torch.int64,
123-
long_indices: bool = True,
124123
) -> Tuple["ModelInput", List["ModelInput"]]: ...
125124

126125

@@ -166,7 +165,6 @@ def gen_model_and_input(
166165
global_constant_batch: bool = False,
167166
num_inputs: int = 1,
168167
input_type: str = "kjt", # "kjt" or "td"
169-
long_indices: bool = True,
170168
) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
171169
torch.manual_seed(0)
172170
if dedup_feature_names:
@@ -229,7 +227,6 @@ def gen_model_and_input(
229227
indices_dtype=indices_dtype,
230228
offsets_dtype=offsets_dtype,
231229
lengths_dtype=lengths_dtype,
232-
long_indices=long_indices,
233230
)
234231
)
235232
else:
@@ -247,7 +244,6 @@ def gen_model_and_input(
247244
indices_dtype=indices_dtype,
248245
offsets_dtype=offsets_dtype,
249246
lengths_dtype=lengths_dtype,
250-
long_indices=long_indices,
251247
)
252248
)
253249
return (model, inputs)

torchrec/distributed/tests/test_quant_sequence_model_parallel.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ def test_quant_pred_shard(
198198
num_float_features=10,
199199
tables=self.tables,
200200
weighted_tables=[],
201-
long_indices=False,
201+
indices_dtype=torch.int32,
202+
lengths_dtype=torch.int32,
202203
)
203204
local_batch = local_batch.to(device)
204205
sharded_quant_model(local_batch.idlist_features)

0 commit comments

Comments (0)