@@ -17,22 +17,17 @@ struct usage_info {
   std::vector<size_t> tensorrt_use_id; // ids of segmented blocks which are of type TensorRT
 };
 
-inline bool isTensorOrTensorList(torch::jit::Value* val) {
-  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
-      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
-}
-
-inline bool isTensorList(torch::jit::Value* val) {
-  return val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
-}
-
 inline bool isTensor(torch::jit::Value* val) {
   return val->type()->isSubtypeOf(torch::jit::TensorType::get());
 }
 
+inline bool isListOrTuple(torch::jit::Value* val) {
+  return val->type()->kind() == torch::jit::TypeKind::TupleType || val->type()->kind() == torch::jit::TypeKind::ListType;
+}
+
 bool containNonTensorOutputs(torch::jit::Node* n) {
   for (auto output : n->outputs()) {
-    if (!isTensorOrTensorList(output)) {
+    if (!isTensor(output)) {
       return true;
     }
   }
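For reference, a minimal sketch (not part of the patch; it assumes the two predicates above are in scope) of how the new helpers split values into the three categories the partitioner now distinguishes:

#include <torch/csrc/jit/ir/ir.h>

// Sketch, not patch code: Tensors cross segment boundaries natively,
// List/Tuple values are now acceptable TRT segment inputs, and anything
// else is a NonTensor that forces fallback handling.
const char* classify(torch::jit::Value* val) {
  if (isTensor(val)) {
    return "tensor";
  } else if (isListOrTuple(val)) {
    return "collection (List/Tuple)";
  }
  return "other NonTensor";
}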
@@ -68,6 +63,7 @@ std::vector<torch::jit::Node*> findModifyingNodes(
   return modifying_nodes;
 }
 
+// This function is only used when a TRT segment produces NonTensor values that are consumed by a later TRT segment
 std::vector<torch::jit::Node*> getDependencyNodes(
     const std::vector<torch::jit::Value*>& vals,
     const SegmentedBlock& seg_block) {
@@ -88,7 +84,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
       stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
       stk.push_back(node);
       for (auto input : node->inputs()) {
-        if (!isTensorOrTensorList(input)) {
+        if (!isTensor(input)) {
           q.push(input);
         }
       }
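The surrounding function walks producers of NonTensor values with an explicit worklist; a standalone sketch of that pattern (illustrative only, simplified from the real traversal, which also tracks modifying nodes and ordering):

#include <queue>
#include <unordered_set>
#include <vector>
#include <torch/csrc/jit/ir/ir.h>

// Sketch, not patch code: collect producer nodes of a set of NonTensor
// values, following NonTensor inputs transitively, roughly as
// getDependencyNodes does.
std::vector<torch::jit::Node*> collectNonTensorProducers(const std::vector<torch::jit::Value*>& vals) {
  std::queue<torch::jit::Value*> q;
  std::unordered_set<torch::jit::Node*> visited;
  std::vector<torch::jit::Node*> deps;
  for (auto v : vals) q.push(v);
  while (!q.empty()) {
    auto v = q.front();
    q.pop();
    auto node = v->node();
    if (!visited.insert(node).second) continue;  // skip already-seen producers
    deps.push_back(node);
    for (auto input : node->inputs()) {
      if (!isTensor(input)) {  // Tensors cross the boundary natively
        q.push(input);
      }
    }
  }
  return deps;
}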
@@ -113,15 +109,19 @@ void find_all_fallback_nodes(
     auto cur_node = q.front();
     q.pop();
     // for every node that produces this fallback node's NonTensor input, they should fallback too
+    // Even though the collection feature is supported, TRT List/Tuple output is not supported yet, so the
+    // nodes that produce a List/Tuple still cannot be in a TRT segment
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
           global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
         q.push(input->node());
       }
     }
     // for every node that consumes this fallback node's NonTensor output, they should fallback too
+    // Since the collection feature is supported, a TRT segment can take List/Tuple inputs, so we only
+    // fall back the nodes that take inputs which are not Tensor/List/Tuple
     for (auto output : cur_node->outputs()) {
-      if (!isTensor(output)) {
+      if (!isTensor(output) && !isListOrTuple(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
           if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
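Put plainly, the asymmetry the two new comments describe: producers of List/Tuple values must still fall back (TRT cannot emit collections), while consumers of a List/Tuple need not (TRT segments may take them as inputs). A sketch of the two directional predicates under that rule (helper names are mine, not from the patch):

// Sketch, not patch code: the directional rules encoded in the loops above.
// Producer side: any non-Tensor input (including List/Tuple, since TRT
// cannot *produce* collections yet) drags the producing node into fallback.
bool inputForcesProducerFallback(torch::jit::Value* input) {
  return !isTensor(input);
}

// Consumer side: List/Tuple outputs are exempt, because a TRT segment may
// *consume* a collection even though it cannot produce one.
bool outputForcesConsumerFallback(torch::jit::Value* output) {
  return !isTensor(output) && !isListOrTuple(output);
}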
@@ -176,7 +176,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
       if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
               seg_block.raw_inputs().end() &&
           seg_block.contain_raw_value(mini_graph_input)) {
-        if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
+        if (!isTensor(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
           continue;
         seg_block.registerOutput(mini_graph_input);
       }
@@ -242,36 +242,6 @@ bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::ji
           "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
               << util::node_info(n));
     }
-  }
-  return false;
-}
-
-bool is_collection(torch::jit::Node* n) {
-  for (auto out : n->outputs()) {
-    if (out->type()->kind() == torch::jit::TypeKind::TupleType || out->type()->kind() == torch::jit::TypeKind::ListType) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set<std::string>& torch_ops) {
-  // If the op is not supported by the conversion phase it should run in PyTorch
-  if (!conversion::OpSupported(n)) {
-    LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    return false;
-  }
-
-  // If the user specifies the op to run in Torch it should run in PyTorch
-  if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) {
-    LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    return false;
-  }
-
-  // If the user specifies the module containing this op to run in torch it should run in PyTorch
-  const auto to_compile_sym = c10::Symbol::attr("to_compile");
-  if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-    LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
   return false;
 }
 
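The deleted `is_collection` survives in spirit: its body is just the new per-value `isListOrTuple` applied across a node's outputs, and the decisions `should_run_in_trt` made inline are now precomputed into `global_fallback_nodes` and read back through `check_node_fallback`. A sketch of the equivalence for the collection check (hypothetical helper name, not patch code):

// Sketch: the removed is_collection(n) expressed with the new per-value
// helper from the top of this file.
bool producesCollection(torch::jit::Node* n) {
  for (auto out : n->outputs()) {
    if (isListOrTuple(out)) {
      return true;
    }
  }
  return false;
}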
@@ -390,19 +360,18 @@ PartitionedGraph segment_graph(
   find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);
 
   auto nodes = block->nodes();
-  auto reverse_nodes = nodes.reverse(); // merge from output side to input side
   PartitionedGraph segmented_blocks;
 
   // segment the nodes
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
-  for (const auto n : reverse_nodes) {
+  for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
     // the outputs of trt subgraph shouldn't be collections
-    if (should_run_in_trt(n, forced_fallback_ops) && !(in_prog_trt_blk_nodes.size() == 0 && is_collection(n))) {
-      in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n);
+    if (check_node_fallback(n, global_fallback_nodes)) {
+      in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
       // block then segment and reset the active PyTorch block
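Dropping `reverse_nodes` also retires the prepend-then-reverse dance: walking forward and appending yields the same order without quadratic `insert(begin(), ...)` calls or the final `std::reverse` over the block list (removed in the last hunk below). A toy standalone check of that equivalence (illustrative only):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  std::vector<int> nodes = {1, 2, 3, 4};  // stand-ins for torch::jit::Node*

  // Old scheme: iterate in reverse and prepend each element.
  std::vector<int> prepended;
  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
    prepended.insert(prepended.begin(), *it);  // O(n) per insert
  }

  // New scheme: iterate forward and append.
  std::vector<int> appended;
  for (int n : nodes) {
    appended.push_back(n);  // O(1) amortized
  }

  assert(prepended == appended);  // identical order, cheaper construction
  return 0;
}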
@@ -418,7 +387,7 @@ PartitionedGraph segment_graph(
         LOG_DEBUG(
             "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
         in_prog_pyt_blk_nodes.insert(
-            in_prog_pyt_blk_nodes.begin(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+            in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
       }
       in_prog_trt_blk_nodes.clear();
       // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
@@ -437,14 +406,14 @@ PartitionedGraph segment_graph(
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
       }
       if (checkLoopEvaluatable(n)) {
-        in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n);
+        in_prog_trt_blk_nodes.push_back(n);
       } else {
         auto loop_node = std::vector<torch::jit::Node*>{n};
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
       }
       continue;
     }
-    in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.begin(), n);
+    in_prog_pyt_blk_nodes.push_back(n);
   }
 }
 
@@ -459,7 +428,6 @@ PartitionedGraph segment_graph(
         in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
     finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
-  std::reverse(segmented_blocks.begin(), segmented_blocks.end());
   return segmented_blocks;
 }
 