Skip to content

Commit d50498d

Browse files
committed
test: add test case for min_block_size cased fallback nodes
Signed-off-by: Bo Wang <[email protected]>
1 parent e9e2799 commit d50498d

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/core/partitioning/test_segmentation.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,46 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
120120
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
121121
}
122122

123+
TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
124+
const auto graph = R"IR(
125+
graph(%0 : Tensor,
126+
%1 : Tensor,
127+
%2 : Tensor):
128+
%3 : int[] = prim::Constant[value=[-1, 5]]()
129+
%4 : int[] = prim::Constant[value=[-1]]()
130+
%5 : int = prim::Constant[value=2]()
131+
%6 : int = prim::Constant[value=4]()
132+
%7 : int = prim::Constant[value=5]()
133+
%8 : int = prim::Constant[value=0]()
134+
%9 : bool = prim::Constant[value=0]()
135+
%10 : NoneType = prim::Constant()
136+
%11 : int = prim::Constant[value=1]()
137+
%12: Tensor = aten::reshape(%1, %4)
138+
%13: Tensor = aten::reshape(%2, %3)
139+
%14: Tensor = aten::reshape(%1, %3)
140+
%15 : Tensor = aten::to(%12, %6, %9, %9, %10)
141+
%16 : int = aten::size(%1, %8)
142+
%17 : int[] = prim::ListConstruct(%16, %6, %5, %7)
143+
%18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11)
144+
%20 : Tensor = aten::reshape(%18, %17)
145+
return (%20))IR";
146+
147+
auto g = std::make_shared<torch::jit::Graph>();
148+
torch::jit::parseIR(graph, g.get());
149+
150+
torch_tensorrt::core::partitioning::PartitionInfo partition_info;
151+
partition_info.enabled = true;
152+
partition_info.min_block_size = 3;
153+
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
154+
std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
155+
torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
156+
ASSERT_TRUE(
157+
checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
158+
ASSERT_TRUE(
159+
checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
160+
ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2, 3}, {4, 5, 6, 7}}));
161+
}
162+
123163
TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
124164
const auto graph = R"IR(
125165
graph(%0 : Tensor,

0 commit comments

Comments
 (0)