@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
   return partitioning::stitch(&partitioning_ctx, block);
 }
 
-void MapInputsAndDetermineDTypes(
+ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
   cfg.partitioning_info.collection_input_spec_map =
       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
+  ir::TypeMap inferred_dtypes;
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
       "In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -239,34 +240,36 @@ void MapInputsAndDetermineDTypes(
           first_use_type_map[in][i] = {
               util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
 
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
-              est_type_opt[i].value()) {
-            std::stringstream ss;
-            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt[i].value() << std::endl;
-            ss << "The compiler is going to use the user setting "
-               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-            ss << "compatibility with PyTorch's data type convention is required.\n";
-            ss << "If you do indeed see errors at runtime either:\n";
-            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-            ss << "- Disable partial compilation by setting require_full_compilation to True";
-            auto warn_str = ss.str();
-            LOG_WARNING(warn_str);
-            // Overwrite type map with user settings
-            first_use_type_map[in][i] = {
-                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-          }
+        } else if (
+            util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
+            est_type_opt[i].value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt[i].value() << std::endl;
+          ss << "The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {
+              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
         }
       } else {
         // The user defined the type so no changes are necessary
       }
+
+      // Insert entry for Value pointer and determined ScalarType
+      inferred_dtypes.insert({in, c10::optional<c10::ScalarType>(util::TRTDataTypeToScalarType(spec[i].dtype))});
     }
   }
-  // }
+  return inferred_dtypes;
 }
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
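
With this hunk, MapInputsAndDetermineDTypes no longer just mutates the compile spec: for each collection input it also records the ScalarType it settled on in inferred_dtypes and returns that map to the caller. Below is a minimal standalone sketch of that resolve-and-report pattern, using simplified stand-ins (string keys instead of torch::jit::Value*, a toy ScalarType enum, a hypothetical ResolveInputDTypes helper, and an assumed Float fallback) rather than the actual Torch-TensorRT types. The precedence mirrors the diff: a user-specified dtype wins, otherwise the dtype inferred from first use, with a warning when the two disagree.

// Standalone illustration only; types and names here are simplified
// stand-ins, not the torch_tensorrt API. Compile with: g++ -std=c++17
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

enum class ScalarType { Float, Half, Int, Long };

struct InputSpec {
  ScalarType dtype = ScalarType::Float;
  bool dtype_is_user_defined = false;
};

const char* name(ScalarType t) {
  switch (t) {
    case ScalarType::Float: return "Float";
    case ScalarType::Half: return "Half";
    case ScalarType::Int: return "Int";
    case ScalarType::Long: return "Long";
  }
  return "?";
}

// Resolve each input's dtype and report the decision, mirroring the patched
// function's new contract: mutate the specs as before, but also return the
// map of resolved types for later passes to consume.
std::unordered_map<std::string, std::optional<ScalarType>> ResolveInputDTypes(
    std::vector<std::pair<std::string, InputSpec>>& specs,
    const std::unordered_map<std::string, ScalarType>& first_use_types) {
  std::unordered_map<std::string, std::optional<ScalarType>> inferred_dtypes;
  for (auto& [in, spec] : specs) {
    auto est = first_use_types.find(in);
    if (!spec.dtype_is_user_defined && est != first_use_types.end()) {
      spec.dtype = est->second;  // adopt the type inferred from first use
    } else if (!spec.dtype_is_user_defined) {
      spec.dtype = ScalarType::Float;  // nothing known: assume a Float default
    } else if (est != first_use_types.end() && est->second != spec.dtype) {
      // User setting conflicts with the inferred type: warn, keep the user's
      std::cerr << "warning: for input " << in << ", user dtype "
                << name(spec.dtype) << " conflicts with inferred "
                << name(est->second) << "; using the user setting\n";
    }
    // Record the decision, as the diff does with inferred_dtypes.insert(...)
    inferred_dtypes.emplace(in, spec.dtype);
  }
  return inferred_dtypes;
}

int main() {
  std::vector<std::pair<std::string, InputSpec>> specs = {
      {"x", {ScalarType::Half, true}},   // user-specified Half
      {"y", {ScalarType::Float, false}}  // no user spec
  };
  std::unordered_map<std::string, ScalarType> first_use = {
      {"x", ScalarType::Float}, {"y", ScalarType::Long}};

  auto type_map = ResolveInputDTypes(specs, first_use);
  for (const auto& [in, t] : type_map) {
    std::cout << in << " -> " << name(*t) << "\n";
  }
}
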
@@ -307,7 +310,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
-  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  // Extract map of IValue to DType
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  // Use dtype map to autocast inputs to the correct type
+  if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double) {
+    lowering::AutocastInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
+  }
+
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partitioning_info.enabled &&
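
In CompileGraph, the returned map then feeds an input autocast step whenever partitioning and truncate_long_and_double are both enabled. The diff only shows the call site, so the sketch below is a guess at the shape of such a pass rather than the body of lowering::AutocastInputs: given the resolved dtype per input and a target device string, it prepends one cast per typed input to a toy textual IR so every downstream partition sees the dtype the compiler decided on. MiniGraph and AutocastInputsSketch are hypothetical names.

// Conceptual sketch only; MiniGraph is a toy stand-in for a
// torch::jit::Graph, and this is not lowering::AutocastInputs itself.
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

struct MiniGraph {
  std::vector<std::string> inputs;  // input value names
  std::vector<std::string> ops;     // ops as plain text, in execution order
};

// For every graph input with a resolved dtype, prepend a cast op so the rest
// of the graph consumes a tensor of the expected type on the expected device.
// This models the role the autocast call plays in the hunk above.
void AutocastInputsSketch(
    MiniGraph& g,
    const std::unordered_map<std::string, std::optional<std::string>>& type_map,
    const std::string& device) {
  std::vector<std::string> casts;
  for (const auto& in : g.inputs) {
    auto it = type_map.find(in);
    if (it != type_map.end() && it->second.has_value()) {
      casts.push_back("%" + in + ".cast = aten::to(%" + in + ", " +
                      *it->second + ", " + device + ")");
    }
  }
  // Insert the casts ahead of all existing ops
  g.ops.insert(g.ops.begin(), casts.begin(), casts.end());
}

int main() {
  MiniGraph g{{"x", "y"}, {"%z = aten::add(%x.cast, %y.cast)"}};
  AutocastInputsSketch(g, {{"x", "Float"}, {"y", "Int"}}, "cuda:0");
  for (const auto& op : g.ops) {
    std::cout << op << "\n";
  }
}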