From f866dba29afa5848ac67d885eaa1e083e2e00177 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Mon, 1 Aug 2022 22:16:17 -0700
Subject: [PATCH] fix: fix the bug that ListConstruct is in TRT subgraph when
 it's entire graph's output

Signed-off-by: Bo Wang
---
 core/partitioning/partitioning.cpp | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 85626772f0..565f58c677 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -90,6 +90,16 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
+void find_nontensor_output_nodes(
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  for (auto i : block->outputs()) {
+    if (!isTensor(i)) {
+      global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR});
+    }
+  }
+}
+
 void find_all_fallback_nodes(
     std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
@@ -430,6 +440,9 @@ PartitionedGraph Partition(
     const PartitionInfo& partition_info,
     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   LOG_DEBUG(partition_info);
+  // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output
+  find_nontensor_output_nodes(block, global_fallback_nodes);
+
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
   PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);