@@ -90,14 +90,26 @@ std::vector<torch::jit::Node*> getDependencyNodes(
90
90
return stk;
91
91
}
92
92
93
- void find_nontensor_output_nodes (
93
+ // check if the input and output of the graph is Tensor after collection is enabled. If it is, then fallback related
94
+ // nodes
95
+ void fallback_graph_nontensor_in_out (
94
96
torch::jit::Block* block,
95
97
std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
98
+ // fallback nodes that produce entire graph's nonTensor output
96
99
for (auto i : block->outputs ()) {
97
100
if (!isTensor (i)) {
98
101
global_fallback_nodes.insert ({i->node (), FallbackNodeType::kNON_TENSOR });
99
102
}
100
103
}
104
+
105
+ // fallback nodes that consume entire graph's nonTensor input
106
+ for (auto i : block->inputs ()) {
107
+ if (!isTensor (i)) {
108
+ for (auto use : i->uses ()) {
109
+ global_fallback_nodes.insert ({use.user , FallbackNodeType::kNON_TENSOR });
110
+ }
111
+ }
112
+ }
101
113
}
102
114
103
115
void find_all_fallback_nodes (
@@ -202,6 +214,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
202
214
}
203
215
}
204
216
}
217
+
205
218
std::for_each (segmented_blocks.begin (), segmented_blocks.end (), [](SegmentedBlock& seg_block) {
206
219
torch::jit::EliminateDeadCode (seg_block.g ());
207
220
});
@@ -440,8 +453,9 @@ PartitionedGraph Partition(
440
453
const PartitionInfo& partition_info,
441
454
std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
442
455
LOG_DEBUG (partition_info);
443
- // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output
444
- find_nontensor_output_nodes (block, global_fallback_nodes);
456
+ // if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor
457
+ // output
458
+ fallback_graph_nontensor_in_out (block, global_fallback_nodes);
445
459
446
460
// segment lowering global graph into blocks
447
461
LOG_DEBUG (" Parititioning source module into PyTorch and TensorRT sub blocks" );
0 commit comments