@@ -308,70 +308,78 @@ void MapInputsAndDetermineDTypes(
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
     ir::CollectionTypeMap& first_use_type_map) {
-    cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+  cfg.convert_info.collection_input_spec_map =
+      std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
 
-    auto collection_inputs = ir::get_collection_inputs(g, static_params);
-    LOG_DEBUG("In MapInputsAndDetermineDTypes, the g->inputs() size is " << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size());
+  auto collection_inputs = ir::get_collection_inputs(g, static_params);
+  LOG_DEBUG(
+      "In MapInputsAndDetermineDTypes, the g->inputs() size is "
+      << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size());
 
-    for (auto in : collection_inputs) {
-      std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
-      std::vector<c10::optional<at::ScalarType>> est_type_opt;
+  for (auto in : collection_inputs) {
+    std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
+    std::vector<c10::optional<at::ScalarType>> est_type_opt;
 
-      auto est_it = first_use_type_map.find(in);
-      if (est_it != first_use_type_map.end()) {
-        est_type_opt = first_use_type_map.find(in)->second;
-      }
-      // traverse elements in est_type_out and spec
-      for (size_t i = 0; i < est_type_opt.size(); i++) {
-        if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
-          // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
-          // type
-          LOG_INFO(
-              "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
-              << in->debugName() << " has type " << est_type_opt[i].value());
-          spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
-        } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
-          // If we cannot calculate the type and the user did not define the type, then default to FP32
-          LOG_WARNING(
-              "Cannot infer input type from calcuations in graph for input "
-              << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-          spec[i].dtype = nvinfer1::DataType::kFLOAT;
-        } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
-          if (!est_type_opt[i]) {
-            LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
+    auto est_it = first_use_type_map.find(in);
+    if (est_it != first_use_type_map.end()) {
+      est_type_opt = first_use_type_map.find(in)->second;
+    }
+    // traverse elements in est_type_out and spec
+    for (size_t i = 0; i < est_type_opt.size(); i++) {
+      if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+        // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
+        // type
+        LOG_INFO(
+            "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
+            << in->debugName() << " has type " << est_type_opt[i].value());
+        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+      } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+        // If we cannot calculate the type and the user did not define the type, then default to FP32
+        LOG_WARNING(
+            "Cannot infer input type from calcuations in graph for input "
+            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
+        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt[i]) {
+          LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ". The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {
+              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+
+        } else {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
+              est_type_opt[i].value()) {
             std::stringstream ss;
             ss << "For input " << in->debugName() << ", found user specified input dtype as ";
             ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+            ss << est_type_opt[i].value() << std::endl;
+            ss << "The compiler is going to use the user setting "
+               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+            ss << "compatibility with PyTorch's data type convention is required.\n";
+            ss << "If you do indeed see errors at runtime either:\n";
+            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+            ss << "- Disable partial compilation by setting require_full_compilation to True";
             auto warn_str = ss.str();
             LOG_WARNING(warn_str);
             // Overwrite type map with user settings
-            first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-
-          } else {
-            if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != est_type_opt[i].value()) {
-              std::stringstream ss;
-              ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-              ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-              ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-              ss << est_type_opt[i].value() << std::endl;
-              ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-              ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-              ss << "compatibility with PyTorch's data type convention is required.\n";
-              ss << "If you do indeed see errors at runtime either:\n";
-              ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-              ss << "- Disable partial compilation by setting require_full_compilation to True";
-              auto warn_str = ss.str();
-              LOG_WARNING(warn_str);
-              // Overwrite type map with user settings
-              first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-            }
+            first_use_type_map[in][i] = {
+                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
           }
-        } else {
-          // The user defined the type so no changes are necessary
         }
+      } else {
+        // The user defined the type so no changes are necessary
       }
     }
+  }
   // }
 }
 
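The hunk above is a formatting-only change, but the dtype-resolution policy it rewraps is easy to lose in the diff noise. Below is a minimal standalone sketch of that policy under simplified assumptions; `ScalarType`, `InputSpec`, and `resolve_dtype` are hypothetical stand-ins for illustration, not the Torch-TensorRT API.

```cpp
#include <iostream>
#include <optional>

// Minimal stand-ins for the real ir::Input / at::ScalarType types (hypothetical).
enum class ScalarType { kFloat, kHalf, kInt };

struct InputSpec {
  ScalarType dtype = ScalarType::kFloat;
  bool dtype_is_user_defined = false;
};

// Same decision order the hunk preserves:
// 1. type inferred from the graph, no user setting -> take the inferred type
// 2. nothing inferred, no user setting             -> default to Float32 and warn
// 3. user setting present and partitioning enabled -> keep the user setting,
//    warning if it disagrees with what the graph implies
void resolve_dtype(InputSpec& spec, std::optional<ScalarType> inferred, bool partitioning_enabled) {
  if (inferred && !spec.dtype_is_user_defined) {
    spec.dtype = *inferred;
  } else if (!inferred && !spec.dtype_is_user_defined) {
    std::cerr << "Cannot infer input type from the graph, assuming Float32\n";
    spec.dtype = ScalarType::kFloat;
  } else if (spec.dtype_is_user_defined && partitioning_enabled) {
    if (inferred && *inferred != spec.dtype) {
      std::cerr << "User dtype disagrees with the inferred dtype; using the user setting\n";
    }
    // The user setting wins; the real code also writes it back into the
    // first-use type map so partitioning sees a consistent type.
  }
  // Otherwise the user defined the type and nothing needs to change.
}

int main() {
  InputSpec spec;
  spec.dtype = ScalarType::kHalf;
  spec.dtype_is_user_defined = true;
  resolve_dtype(spec, ScalarType::kFloat, /*partitioning_enabled=*/true); // prints the mismatch warning
  return 0;
}
```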
@@ -425,12 +433,13 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
 
   if (cfg.partition_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-      cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)
-      || outputIsCollection)) {
-
+         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+       outputIsCollection)) {
     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-    auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
-    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
+    auto collection_input_ivalues_map =
+        partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
+    auto graph_and_mapping = ConstructFallbackGraph(
+        new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
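The rewrapped condition above is dense; read as a predicate, it says the fallback-graph path runs when partitioning is enabled and either something forces fallback (forced modules or operators, or a block that is not fully convertible) or the graph returns a collection. A hedged restatement with illustrative names that are not part of the codebase:

```cpp
#include <cstddef>

// Restates the gating condition from the hunk above as a standalone predicate.
// needs_fallback_graph and its parameters are hypothetical, for illustration only.
bool needs_fallback_graph(
    bool partitioning_enabled,
    std::size_t forced_fallback_modules,
    std::size_t forced_fallback_operators,
    bool block_is_convertible,
    bool output_is_collection) {
  // Nothing forces fallback only when there are no forced modules/operators and the
  // whole block converts to TensorRT; a collection output always takes the fallback path.
  bool nothing_forces_fallback =
      forced_fallback_modules == 0 && forced_fallback_operators == 0 && block_is_convertible;
  return partitioning_enabled && (!nothing_forces_fallback || output_is_collection);
}
```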