@@ -275,81 +275,120 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
   return compile_to_trt;
 }
 
-std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
+bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set<std::string>& torch_ops) {
+  // If the op is not supported by the conversion phase it should run in PyTorch
+  if (!conversion::OpSupported(n)) {
+    LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
+    return false;
+  }
+
+  // If the user specifies the op to run in Torch it should run in PyTorch
+  if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) {
+    LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
+    return false;
+  }
+
+  // If the user specifies the module containing this op to run in Torch it should run in PyTorch
+  const auto to_compile_sym = c10::Symbol::attr("to_compile");
+  if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
+    LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+    return false;
+  }
+
+  LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n));
+  return true;
+}
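The predicate above replaces the inline condition from the old segment_graph and centralizes the three fallback rules: converter support, per-op forced fallback, and per-module fallback via the to_compile attribute. A rough sketch of its behavior (the node n and the set contents are illustrative, not part of this patch):

    // Sketch only: for a node n of kind aten::relu
    std::unordered_set<std::string> forced = {"aten::relu"};
    should_run_in_trt(n, forced);  // false: the op is on the forced fallback list
    should_run_in_trt(n, {});      // true, assuming a converter exists and no
                                   // enclosing module was marked to_compile = false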
+
+void finalize_block(PartitionedGraph& g, SegmentedBlock::SegmentedBlockTarget kind, std::vector<torch::jit::Node*>& nodes) {
+  SegmentedBlock::BlockID b_id = g.size();
+  LOG_DEBUG("Finalizing in progress " << SegmentedBlock::target_to_str(kind) << " block");
+  g.emplace_back(b_id, kind, nodes);
+  nodes.clear();
+  LOG_DEBUG(g.back());
+}
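Note that finalize_block both appends the finished segment, with a fresh BlockID taken from the current graph size, and clears the caller's node list, so the same in-progress vector can immediately start accumulating the next block.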
+
+PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
   auto min_block_size = partition_info.min_block_size;
-  std::unordered_set<std::string> forced_fallback_operators(
+  std::unordered_set<std::string> forced_fallback_ops(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
   auto nodes = block->nodes();
-  std::vector<SegmentedBlock> segmented_blocks;
+  PartitionedGraph segmented_blocks;
 
   // segment the nodes
-  std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+  std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
   for (const auto n : nodes) {
+    // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
 
-    std::string node_string(n->kind().toQualString());
-    auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
-    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) &&
-        (!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
-      tensorrt_nodes.push_back(n);
-      if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
-        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-        pytorch_nodes.clear();
+    if (should_run_in_trt(n, forced_fallback_ops)) {
+      in_prog_trt_blk_nodes.push_back(n);
+
+      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
+      // block then segment and reset the active PyTorch block
+      if (in_prog_trt_blk_nodes.size() >= min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+        finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
       }
     } else {
-      if (tensorrt_nodes.size() >= min_block_size) {
-        segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+      // If there is an active TRT block that is valid, segment and reset the active TRT block;
+      // otherwise fold it into the active PyTorch block and reset
+      if (in_prog_trt_blk_nodes.size() >= min_block_size) {
+        finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
       } else {
-        pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
+        LOG_DEBUG("In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
+        in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
      }
-      tensorrt_nodes.clear();
+      in_prog_trt_blk_nodes.clear();
      // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
      // we shouldn't inject node for this block in dependency analysis process
      if (n->kind() == torch::jit::prim::If) {
-        if (!pytorch_nodes.empty()) {
-          segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-          pytorch_nodes.clear();
+        LOG_DEBUG("Hit a conditional statement, finalizing in progress PYT block and creating a new one for the conditional");
+        if (!in_prog_pyt_blk_nodes.empty()) {
+          finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
        }
-        segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+        auto cond_node = std::vector<torch::jit::Node*>{n};
+        finalize_block(segmented_blocks, SegmentedBlock::kTorch, cond_node);
        continue;
      } else if (n->kind() == torch::jit::prim::Loop) {
-        if (!pytorch_nodes.empty()) {
-          segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-          pytorch_nodes.clear();
+        if (!in_prog_pyt_blk_nodes.empty()) {
+          finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
        }
        if (checkLoopEvaluatable(n)) {
-          tensorrt_nodes.push_back(n);
+          in_prog_trt_blk_nodes.push_back(n);
        } else {
-          segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+          auto loop_node = std::vector<torch::jit::Node*>{n};
+          finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
        }
        continue;
      }
-      pytorch_nodes.push_back(n);
+      in_prog_pyt_blk_nodes.push_back(n);
    }
  }
 
   // if there are any kTorch nodes left, then either the last nodes are kTorch or the last nodes are kTensorRT but
   // num < min_block_size
-  if (!pytorch_nodes.empty()) {
-    pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
-    segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-  } else {
-    segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+  if (in_prog_trt_blk_nodes.size() >= min_block_size) {
+    finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+  }
+
+  if (!in_prog_pyt_blk_nodes.empty()) {
+    in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+    finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
 
   return std::move(segmented_blocks);
 }
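To make the thresholding concrete, here is a hand-worked trace (illustrative, assuming min_block_size = 3 and no constants, conditionals, or loops). For the node sequence T T P T T T P, where T means should_run_in_trt returns true and P means it returns false: the first two T nodes never reach the threshold, so they are folded into the in-progress PyTorch block when the first P arrives; the next three T nodes do reach it, which finalizes the Torch block {T, T, P} and, when the final P arrives, the TensorRT block {T, T, T}; the trailing P is flushed by the post-loop logic. The result is Torch{T, T, P}, TensorRT{T, T, T}, Torch{P}.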
 
-std::vector<SegmentedBlock> Partition(
+PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
     const PartitionInfo& partition_info) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
-  std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
+  LOG_DEBUG("Partitioning source module into PyTorch and TensorRT sub blocks");
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info);
 
   // resolve nonTensor inputs/outputs
   resolveNonTensorInputs(segmented_blocks);
@@ -358,11 +397,22 @@ std::vector<SegmentedBlock> Partition(
   registerSegmentsOutputs(segmented_blocks, block);
 
   // run shape analysis on each segmented block
-  runShapeAnalysis(segmented_blocks, input_ivalues_map);
+  runShapeAnalysis(segmented_blocks, input_ivalues_map, at::kFloat);
+
+  LOG_INFO(segmented_blocks);
 
   return segmented_blocks;
 }
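For orientation, a rough sketch of how this entry point might be driven. Everything below is illustrative: in the real pipeline the compiler supplies the lowered graph g and seeds the ivalue map from sample inputs.

    // Sketch only: g is a std::shared_ptr<torch::jit::Graph> after lowering
    PartitionInfo info;
    info.min_block_size = 3;
    info.forced_fallback_operators = {"aten::relu"};
    std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_map;
    PartitionedGraph blocks = Partition(g->block(), ivalues_map, info);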
 
+std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g) {
+  os << "Partitioned Graph: [";
+  for (auto b : g) {
+    os << b;
+  }
+  os << "]";
+  return os;
+}
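This streaming overload is what the new LOG_INFO(segmented_blocks) call in Partition relies on: sending a PartitionedGraph to an ostream prints each SegmentedBlock in order between the brackets, delegating to the block's own operator<<.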
415
+
366
416
} // namespace partitioning
367
417
} // namespace core
368
418
} // namespace trtorch