@@ -120,7 +120,6 @@ def __call__(
120
120
indices_dtype : torch .dtype = torch .int64 ,
121
121
offsets_dtype : torch .dtype = torch .int64 ,
122
122
lengths_dtype : torch .dtype = torch .int64 ,
123
- long_indices : bool = True ,
124
123
) -> Tuple ["ModelInput" , List ["ModelInput" ]]: ...
125
124
126
125
@@ -166,7 +165,6 @@ def gen_model_and_input(
166
165
global_constant_batch : bool = False ,
167
166
num_inputs : int = 1 ,
168
167
input_type : str = "kjt" , # "kjt" or "td"
169
- long_indices : bool = True ,
170
168
) -> Tuple [nn .Module , List [Tuple [ModelInput , List [ModelInput ]]]]:
171
169
torch .manual_seed (0 )
172
170
if dedup_feature_names :
@@ -229,7 +227,6 @@ def gen_model_and_input(
229
227
indices_dtype = indices_dtype ,
230
228
offsets_dtype = offsets_dtype ,
231
229
lengths_dtype = lengths_dtype ,
232
- long_indices = long_indices ,
233
230
)
234
231
)
235
232
else :
@@ -247,7 +244,6 @@ def gen_model_and_input(
247
244
indices_dtype = indices_dtype ,
248
245
offsets_dtype = offsets_dtype ,
249
246
lengths_dtype = lengths_dtype ,
250
- long_indices = long_indices ,
251
247
)
252
248
)
253
249
return (model , inputs )
0 commit comments