#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
+ #include "torch/csrc/jit/ir/ir_views.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/loop_unrolling.h"
#include "torch/csrc/jit/passes/lower_graph.h"
@@ -173,10 +174,131 @@ void AddSegmentedBlockToGraph(
  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
  }
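+   // map raw inputs that are not yet in old_to_new_g; inputs of a TensorRT segment's mini graph are shifted by
+   // one relative to its raw inputs, hence the offset below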
+   size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
+   for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
+     if (!old_to_new_g.count(seg.raw_inputs()[i])) {
+       old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
+     }
+   }

  return;
}

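+ // a converted (sub)graph paired with the mapping from values in the original graph to values in that graph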
+ typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
+     GraphAndMapping;
+
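+ // rebuild a prim::If node in new_g from the already-converted block graphs (one GraphAndMapping per block)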
+ void AddIfBlockToGraph(
+     std::shared_ptr<torch::jit::Graph>& new_g,
+     torch::jit::Node* if_node,
+     const std::vector<GraphAndMapping>& graph_and_mappings,
+     std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
+   torch::jit::IfView if_view(if_node);
+
+   // create a new prim::If node in new_g and add the corresponding condition input
+   auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
+   new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
+
+   // iterate over all blocks and add them to the newly created prim::If
+   for (auto graph_and_mapping : graph_and_mappings) {
+     auto new_if_block = new_if->addBlock();
+     auto cur_block_graph = graph_and_mapping.first;
+     auto cur_block_mapping = graph_and_mapping.second;
+     std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
+     for (auto& i : cur_block_mapping) {
+       // for every pair (old value => mini graph value) in the block mapping: if the old value also appears in
+       // old_to_new_g, it is an input of the mini graph
+       if (old_to_new_g.count(i.first)) {
+         block_graph_to_new_g[i.second] = old_to_new_g[i.first];
+       }
+     }
+
+     auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
+     new_if_block->cloneFrom(cur_block_graph->block(), env);
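+     // if the block graph takes a module "self" argument (its type name contains "__torch__"), make sure new_g
+     // exposes the same self input and map the block's self value to it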
+     if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
+       if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+         auto self = new_g->insertInput(0, "self_1");
+         self->setType(cur_block_graph->inputs()[0]->type());
+       }
+       block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
+     }
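+     // redirect uses of the cloned block's inputs to the mapped values and erase the inputs, since prim::If
+     // blocks do not take block inputs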
+     for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
+       new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
+       new_if_block->eraseInput(i);
+     }
+   }
+   for (auto ov : if_view.outputs()) {
+     auto no = new_if->addOutput();
+     old_to_new_g[ov] = no;
+     no->copyMetadata(ov);
+   }
+   return;
+ }
+
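+ // recursively partition a block into TensorRT and Torch segments, convert the TensorRT segments into engines,
+ // and stitch all segments back together into a single fallback graph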
+ GraphAndMapping ConstructFallbackGraph(
+     torch::jit::script::Module& new_mod,
+     torch::jit::Block* block,
+     std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
+     CompileSpec cfg,
+     conversion::GraphParams named_params) {
+   auto convert_cfg = cfg.convert_info;
+   auto partition_info = cfg.partition_info;
+
+   auto new_g = std::make_shared<torch::jit::Graph>();
+
+   auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);
+
+   // the mapping from the lowered graph to the fallback global graph
+   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
+   for (auto input : block->inputs()) {
+     util::getOrAddInputForValue(input, new_g, old_to_new_g);
+   }
+
+   for (auto& seg_block : segmented_blocks) {
+     LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
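+     // the address of the segmented block is used as a unique identifier for the engine generated from it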
+     std::ostringstream trt_engine_id;
+     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+       std::vector<ir::Input> inputs;
+       for (auto& shape : seg_block.in_shape()) {
+         inputs.push_back(ir::Input(shape));
+       }
+       // update the input ranges for each segment
+       convert_cfg.inputs = inputs;
+       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+       auto temp_g = std::make_shared<torch::jit::Graph>();
+       auto device_spec = convert_cfg.engine_settings.device;
+       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+       AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
+
+       seg_block.update_graph(temp_g);
+       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+     } else {
+       if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
+         auto if_node = seg_block.raw_nodes()[0];
+
+         // convert both blocks of the prim::If node and collect the converted graphs with their value mappings
+         std::vector<GraphAndMapping> graph_and_mappings;
+         for (auto cur_block : if_node->blocks()) {
+           graph_and_mappings.push_back(
+               ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
+         }
+         AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+
+       } else {
+         AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+       }
+     }
+   }
+
+   for (auto& output : block->outputs()) {
+     if (old_to_new_g.count(output)) {
+       new_g->registerOutput(old_to_new_g[output]);
+     }
+   }
+   return {new_g, old_to_new_g};
+ }
+
torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
  // TODO: Should be doing a functional transform but need PR #31978
  // [jit] More robust mangling
@@ -192,53 +314,24 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
  auto g = graph_and_parameters.first;
  auto params = graph_and_parameters.second;
  auto named_params = conversion::get_named_params(g->inputs(), params);
-   auto convert_cfg = std::move(cfg.convert_info);
-   LOG_INFO(*g << "(LoweringGraph)\n");
+   LOG_INFO("(LoweredGraph)\n" << *g);

-   // segment the graph and convert segmented TensorRT block
-   auto segmented_blocks = partitioning::Partition(g, convert_cfg.inputs, cfg.partition_info);
-   if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
+   std::unordered_map<torch::jit::Value*, ir::Input> inputs;
+   for (size_t i = 0; i < g->inputs().size(); ++i) {
+     inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
+   }
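+   // build example IValues for the graph inputs; the partitioner runs them through the segments to infer each
+   // segment's input shapes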
+   auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
+   auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
+   new_g = graph_and_mapping.first;
+   LOG_INFO("(FallbackGraph)\n" << *new_g);
+
+   // if the fallback graph has no TensorRT engine "self" input, nothing was converted, so just return the
+   // original module
+   if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
    LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
    return mod;
  }

-   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-   // add global graph's input to old_to_new_g mapping
-   for (auto input : g->inputs()) {
-     util::getOrAddInputForValue(input, new_g, old_to_new_g);
-   }
-   for (auto& seg_block : segmented_blocks) {
-     std::string cur_block_target =
-         seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
-     LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
-     std::ostringstream trt_engine_id;
-     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-       std::vector<ir::Input> inputs;
-       for (auto& shape : seg_block.in_shape()) {
-         inputs.push_back(ir::Input(shape));
-       }
-       // update the input ranges for each segments
-       convert_cfg.inputs = inputs;
-       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
-       auto temp_g = std::make_shared<torch::jit::Graph>();
-       auto device_spec = convert_cfg.engine_settings.device;
-       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-       AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-
-       seg_block.update_graph(temp_g);
-       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-     } else {
-       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-     }
-   }
-
-   for (auto& output : g->outputs()) {
-     new_g->registerOutput(old_to_new_g[output]);
-   }
-
-   LOG_INFO(*new_g << "(FallbackGraph)\n");
-
  auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
  auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
  new_mod.type()->addMethod(new_method);