diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 85626772f0..565f58c677 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -90,6 +90,16 @@ std::vector getDependencyNodes( return stk; } +void find_nontensor_output_nodes( + torch::jit::Block* block, + std::unordered_map& global_fallback_nodes) { + for (auto i : block->outputs()) { + if (!isTensor(i)) { + global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR}); + } + } +} + void find_all_fallback_nodes( std::unordered_map& initial_fallback_nodes, std::unordered_map& global_fallback_nodes) { @@ -430,6 +440,9 @@ PartitionedGraph Partition( const PartitionInfo& partition_info, std::unordered_map& global_fallback_nodes) { LOG_DEBUG(partition_info); + // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output + find_nontensor_output_nodes(block, global_fallback_nodes); + // segment lowering global graph into blocks LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks"); PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);