@@ -90,6 +90,16 @@ std::vector<torch::jit::Node*> getDependencyNodes(
90
90
return stk;
91
91
}
92
92
93
+ void find_nontensor_output_nodes (
94
+ torch::jit::Block* block,
95
+ std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
96
+ for (auto i : block->outputs ()) {
97
+ if (!isTensor (i)) {
98
+ global_fallback_nodes.insert ({i->node (), FallbackNodeType::kNON_TENSOR });
99
+ }
100
+ }
101
+ }
102
+
93
103
void find_all_fallback_nodes (
94
104
std::unordered_map<torch::jit::Node*, int >& initial_fallback_nodes,
95
105
std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
@@ -430,6 +440,9 @@ PartitionedGraph Partition(
430
440
const PartitionInfo& partition_info,
431
441
std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
432
442
LOG_DEBUG (partition_info);
443
+ // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output
444
+ find_nontensor_output_nodes (block, global_fallback_nodes);
445
+
433
446
// segment lowering global graph into blocks
434
447
LOG_DEBUG (" Parititioning source module into PyTorch and TensorRT sub blocks" );
435
448
PartitionedGraph segmented_blocks = segment_graph (block, partition_info, global_fallback_nodes);
0 commit comments