@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
   return partitioning::stitch(&partitioning_ctx, block);
 }

-void MapInputsAndDetermineDTypes(
+ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
   cfg.partitioning_info.collection_input_spec_map =
       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

+  ir::TypeMap inferred_dtypes;
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
       "In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@ void MapInputsAndDetermineDTypes(
         LOG_INFO(
             "Since input type is not explicitly defined, infering using first tensor calculation\n  Inferred input "
             << in->debugName() << " has type " << est_type_opt[i].value());
-        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+        spec[i].dtype = est_type_opt[i].value();
       } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+        spec[i].dtype = at::kFloat;
       } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@ void MapInputsAndDetermineDTypes(
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
           // Overwrite type map with user settings
-          first_use_type_map[in][i] = {
-              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
-              est_type_opt[i].value()) {
-            std::stringstream ss;
-            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt[i].value() << std::endl;
-            ss << "The compiler is going to use the user setting "
-               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-            ss << "compatibility with PyTorch's data type convention is required.\n";
-            ss << "If you do indeed see errors at runtime either:\n";
-            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-            ss << "- Disable partial compilation by setting require_full_compilation to True";
-            auto warn_str = ss.str();
-            LOG_WARNING(warn_str);
-            // Overwrite type map with user settings
-            first_use_type_map[in][i] = {
-                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-          }
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
+
+        } else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt[i].value() << std::endl;
+          ss << "The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
         }
       } else {
         // The user defined the type so no changes are necessary
       }
+
+      // Insert entry for Value pointer and determined ScalarType
+      inferred_dtypes.insert({in, {spec[i].dtype}});
     }
   }
-  // }
+  return inferred_dtypes;
 }

 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

+  // Ensure none of the specified types are of acceptable input types incompatible with TRT
+  // Currently, only at::kLong is an acceptable, though TRT-incompatible type
+  for (auto value_to_dtypes : first_use_types) {
+    for (auto dtype : value_to_dtypes.second) {
+      TORCHTRT_CHECK(
+          !dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
+    }
+  }
+
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);

   return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

-  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  // Extract map of IValue to DType
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  // Check whether any of the input types are Long
+  bool user_requested_long = false;
+  for (auto dtype : type_map) {
+    user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
+  }
+
+  // Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
+  if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
+    auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
+    user_requested_long &= (casts_inserted > 0);
+  }
+
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled &&
+  if (cfg.partitioning_info.enabled && !user_requested_long &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
      !outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   if (cfg.partitioning_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection)) {
+       outputIsCollection || user_requested_long)) {
     auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
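The last two hunks gate full TRT compilation on whether any input was inferred or declared as Long and whether casts were actually inserted. Below is a minimal standalone sketch (not part of the patch) of that decision flow. Since ir::TypeMap, torch::jit::Value, and lowering::AutocastLongInputs are not usable outside the codebase, a plain std::map keyed by input index, a local ScalarType enum, and a placeholder cast count stand in for them.

// Standalone sketch of the Long-input gating added to CompileGraph.
#include <iostream>
#include <map>
#include <optional>

enum class ScalarType { Float, Long, Int };

int main() {
  // Stand-in for the ir::TypeMap returned by MapInputsAndDetermineDTypes:
  // input index -> inferred dtype (nullopt when inference failed).
  std::map<int, std::optional<ScalarType>> type_map = {
      {0, ScalarType::Float}, {1, ScalarType::Long}, {2, std::nullopt}};

  // Check whether any of the input types are Long, mirroring the
  // user_requested_long loop in the hunk above.
  bool user_requested_long = false;
  for (const auto& entry : type_map) {
    user_requested_long |= entry.second && (entry.second.value() == ScalarType::Long);
  }

  // If Long inputs were requested, casts are inserted; if no casts were
  // actually inserted, the flag is cleared so full TRT compilation is
  // not forced into the hybrid-graph path unnecessarily.
  if (user_requested_long) {
    int casts_inserted = 2;  // placeholder for lowering::AutocastLongInputs(...)
    user_requested_long &= (casts_inserted > 0);
  }

  std::cout << "fall back to hybrid graph: " << std::boolalpha << user_requested_long << "\n";
}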