@@ -128,22 +128,6 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }
 
-std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
-  // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
-
-  auto convert_cfg = std::move(cfg.convert_info);
-  auto g = graph_and_parameters.first;
-
-  auto params = graph_and_parameters.second;
-  auto named_params = conversion::get_named_params(g->inputs(), params);
-
-  LOG_INFO(*g << "(CompileGraph)\n");
-
-  auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
-  return std::move(engine);
-}
-
 void AddSegmentedBlockToGraph(
     std::shared_ptr<torch::jit::Graph>& g,
     partitioning::SegmentedBlock& seg,
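
This hunk drops `ConvertGraphToTRTEngine` from its old position; it is re-added below `ConstructFallbackGraph` with the new `ir::StaticParams` plumbing. For context, a minimal usage sketch of the function as a caller sees it; the `core::` namespace, header path, and output file name here are assumptions, not part of the diff:

```cpp
// Hedged sketch: ConvertGraphToTRTEngine returns the serialized TensorRT
// engine as a std::string, which a caller can persist for later deployment.
#include <fstream>
#include <string>
#include <torch/script.h>
#include "core/compiler.h" // assumed header exposing core::ConvertGraphToTRTEngine

void SaveForwardEngine(const torch::jit::script::Module& mod, core::CompileSpec cfg) {
  std::string engine = core::ConvertGraphToTRTEngine(mod, "forward", cfg);
  std::ofstream out("forward.engine", std::ios::binary); // illustrative path
  out.write(engine.data(), engine.size());
}
```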
@@ -237,15 +221,15 @@ void AddIfBlockToGraph(
 GraphAndMapping ConstructFallbackGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
-    conversion::GraphParams named_params) {
+    ir::StaticParams static_params) {
   auto convert_cfg = cfg.convert_info;
   auto partition_info = cfg.partition_info;
 
   auto new_g = std::make_shared<torch::jit::Graph>();
 
-  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);
+  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);
 
   // the mapping from lowering graph => fallback global graph
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
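
The new signature replaces the converter-level `conversion::GraphParams` with the IR-level `ir::StaticParams` and renames the input map to say what it holds: example tensors keyed by graph `Value`s, used to drive partitioning. A rough sketch of what these types presumably look like (the real aliases live in the library's `ir` namespace and may differ):

```cpp
// Assumed shapes of the types in the new ConstructFallbackGraph signature;
// illustrative only, not the library's actual definitions.
#include <map>
#include <unordered_map>
#include <torch/csrc/jit/ir/ir.h>

namespace ir {
// Frozen module attributes (weights, biases), keyed by the graph input
// Value each parameter was hoisted into during lowering.
using StaticParams = std::map<torch::jit::Value*, torch::jit::IValue>;
} // namespace ir

// Example tensors that stand in for each graph input during shape analysis.
using ExampleTensorMap =
    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>;
```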
@@ -259,13 +243,17 @@ GraphAndMapping ConstructFallbackGraph(
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      auto shapes = seg_block.in_shapes();
+      auto types = seg_block.in_types();
       std::vector<ir::Input> inputs;
-      for (auto& shape : seg_block.in_shape()) {
-        inputs.push_back(ir::Input(shape));
+      for (size_t i = 0; i < shapes.size(); i++) {
+        auto in = ir::Input(shapes[i]);
+        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+        inputs.push_back(in);
       }
       // update the input ranges for each segment
-      convert_cfg.inputs = inputs;
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
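
Each TensorRT-bound segment now records a dtype for every input alongside its shape, so engines are no longer built assuming Float32 everywhere. A minimal sketch of what the `ScalarType` to TensorRT dtype mapping can look like; the real `util::ScalarTypeToTRTDataType` likely covers more cases and uses the project's own error handling:

```cpp
// Illustrative mapping from ATen/c10 scalar types to TensorRT data types.
#include <stdexcept>
#include <c10/core/ScalarType.h>
#include <NvInfer.h>

nvinfer1::DataType ScalarTypeToTRTDataTypeSketch(c10::ScalarType t) {
  switch (t) {
    case c10::kFloat: return nvinfer1::DataType::kFLOAT;
    case c10::kHalf:  return nvinfer1::DataType::kHALF;
    case c10::kChar:  return nvinfer1::DataType::kINT8;
    case c10::kInt:   return nvinfer1::DataType::kINT32;
    case c10::kBool:  return nvinfer1::DataType::kBOOL;
    default:
      throw std::runtime_error("unsupported tensor dtype for TensorRT");
  }
}
```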
@@ -281,7 +269,7 @@ GraphAndMapping ConstructFallbackGraph(
       std::vector<GraphAndMapping> graph_and_mappings;
       for (auto cur_block : if_node->blocks()) {
         graph_and_mappings.push_back(
-            ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
+            ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
       }
       AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
 
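
The recursion through `prim::If` blocks is unchanged apart from the renamed arguments; each branch is partitioned with the same example tensors and static parameters. `GraphAndMapping` is presumably a pair of the rebuilt branch graph and its old-to-new `Value` mapping, along the lines of this sketch:

```cpp
// Assumed alias for ConstructFallbackGraph's return type; illustrative only.
#include <memory>
#include <unordered_map>
#include <utility>
#include <torch/csrc/jit/ir/ir.h>

using GraphAndMapping = std::pair<
    std::shared_ptr<torch::jit::Graph>,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>>;
```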
@@ -299,54 +287,28 @@ GraphAndMapping ConstructFallbackGraph(
   return {new_g, old_to_new_g};
 }
 
-torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
-    if (method.name().compare("forward") == 0) {
-      auto new_g = std::make_shared<torch::jit::Graph>();
-      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
+  // Go through Lowering to simplify graph and extract weight parameters
+  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
 
-      auto g = graph_and_parameters.first;
-      auto params = graph_and_parameters.second;
-      auto named_params = conversion::get_named_params(g->inputs(), params);
-      LOG_INFO("(LoweredGraph)\n" << *g);
+  auto convert_cfg = std::move(cfg.convert_info);
+  auto g = graph_and_parameters.first;
 
-      std::unordered_map<torch::jit::Value*, ir::Input> inputs;
-      for (size_t i = 0; i < g->inputs().size(); ++i) {
-        inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
-      }
-      auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
-      auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
-      new_g = graph_and_mapping.first;
-      LOG_INFO("(FallbackGraph)\n" << *new_g);
+  auto params = graph_and_parameters.second;
+  auto static_params = ir::get_static_params(g->inputs(), params);
 
-      // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
-      // module
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
-        return mod;
-      }
+  LOG_INFO(*g << "(CompileGraph)\n");
 
-      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
-      auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
-      new_mod.type()->addMethod(new_method);
-      new_method->setSchema(schema);
-    }
-  }
+  // Move the user-defined input specs into convert_cfg, since some graph inputs may be static parameters
+  convert_cfg.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
 
-  return new_mod;
+  auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, static_params);
+  return std::move(engine);
 }
 
-torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
-  // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporally
-  if (cfg.partition_info.enabled) {
-    return CompileGraphWithFallback(mod, cfg);
-  }
+torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
+  torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
+
   auto device_spec = cfg.convert_info.engine_settings.device;
 
   // GPU default WS size : 1 GB
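
The rewritten `ConvertGraphToTRTEngine` no longer assumes the user's specs line up one-to-one with `g->inputs()`: `ir::associate_specs_with_inputs` pairs each spec with a non-static graph input, skipping inputs that lowering froze into parameters. A hedged sketch of that pairing, assuming a map-style `InputSpecMap` alias and that the caller supplies exactly one spec per runtime input:

```cpp
// Illustrative pairing of user input specs with non-parameter graph inputs;
// ir::Input and ir::StaticParams are the library's types, the rest is assumed.
#include <unordered_map>
#include <vector>
#include <torch/csrc/jit/ir/ir.h>

using InputSpecMap = std::unordered_map<const torch::jit::Value*, ir::Input>;

InputSpecMap AssociateSpecsWithInputsSketch(
    std::shared_ptr<torch::jit::Graph>& g,
    const std::vector<ir::Input>& specs,
    ir::StaticParams& static_params) {
  InputSpecMap input_specs;
  size_t spec_idx = 0;
  for (auto* in : g->inputs()) {
    // Frozen weights don't consume a user spec; only runtime inputs do.
    if (static_params.find(in) == static_params.end()) {
      input_specs.insert({in, specs[spec_idx++]});
    }
  }
  return input_specs;
}
```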
@@ -362,25 +324,59 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
     }
   }
 
-  // TODO: Should be doing a functional transform but need PR #31978
-  // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
-  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
-  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
-  for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Compile only forward methods. forward method contains the entire graph.
+  for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
-      auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+
+      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
+
+      auto g = graph_and_parameters.first;
+      LOG_INFO("Lowered Graph: " << *g);
+      auto params = graph_and_parameters.second;
+      auto static_params = ir::get_static_params(g->inputs(), params);
+
+      cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
+
+      // If the user did not explicitly set the input type, use the input's first
+      // tensor calculation to infer its type.
+      auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+      for (auto& in : g->inputs()) {
+        auto est_type_opt = first_use_types[in];
+        ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
+        if (est_type_opt && !spec.dtype_is_user_defined) {
+          spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
+        } else if (!est_type_opt && !spec.dtype_is_user_defined) {
+          LOG_WARNING(
+              "Cannot determine input type from calculations in graph for input "
+              << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
+          spec.dtype = nvinfer1::DataType::kFLOAT;
+        }
+      }
+
+      if (cfg.partition_info.enabled) {
+        auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+        auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
+        new_g = graph_and_mapping.first;
+        LOG_INFO("Segmented Graph: " << *new_g);
+
+        // If no TensorRT engine made it into the fallback graph, there was no conversion;
+        // just return the initial module
+        if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+          LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
+          return mod;
+        }
+      } else {
+        auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
+        auto device_spec = cfg.convert_info.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+      }
       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
       auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
       new_mod.type()->addMethod(new_method);
       new_method->setSchema(schema);
     }
   }
-
   return new_mod;
 }
 
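The dtype-inference loop in the new `CompileGraph` encodes a simple precedence: an explicit user-set dtype wins, otherwise the type inferred from the input's first tensor use, otherwise a Float32 default with a warning. Factored out as a standalone sketch (the function name and signature are illustrative, not the library's API):

```cpp
// Hedged sketch of the input dtype resolution policy used above.
#include <optional>
#include <NvInfer.h>

nvinfer1::DataType ResolveInputDType(
    std::optional<nvinfer1::DataType> user_dtype,       // from the user's Input spec
    std::optional<nvinfer1::DataType> inferred_dtype) { // from first tensor use
  if (user_dtype) {
    return *user_dtype; // an explicit setting always wins
  }
  if (inferred_dtype) {
    return *inferred_dtype; // otherwise trust the graph analysis
  }
  return nvinfer1::DataType::kFLOAT; // conservative default (warned about)
}
```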