@@ -120,6 +120,46 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
120
120
ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 }, {3 , 4 }}));
121
121
}
122
122
123
+ TEST (Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
124
+ const auto graph = R"IR(
125
+ graph(%0 : Tensor,
126
+ %1 : Tensor,
127
+ %2 : Tensor):
128
+ %3 : int[] = prim::Constant[value=[-1, 5]]()
129
+ %4 : int[] = prim::Constant[value=[-1]]()
130
+ %5 : int = prim::Constant[value=2]()
131
+ %6 : int = prim::Constant[value=4]()
132
+ %7 : int = prim::Constant[value=5]()
133
+ %8 : int = prim::Constant[value=0]()
134
+ %9 : bool = prim::Constant[value=0]()
135
+ %10 : NoneType = prim::Constant()
136
+ %11 : int = prim::Constant[value=1]()
137
+ %12: Tensor = aten::reshape(%1, %4)
138
+ %13: Tensor = aten::reshape(%2, %3)
139
+ %14: Tensor = aten::reshape(%1, %3)
140
+ %15 : Tensor = aten::to(%12, %6, %9, %9, %10)
141
+ %16 : int = aten::size(%1, %8)
142
+ %17 : int[] = prim::ListConstruct(%16, %6, %5, %7)
143
+ %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11)
144
+ %20 : Tensor = aten::reshape(%18, %17)
145
+ return (%20))IR" ;
146
+
147
+ auto g = std::make_shared<torch::jit::Graph>();
148
+ torch::jit::parseIR (graph, g.get ());
149
+
150
+ torch_tensorrt::core::partitioning::PartitionInfo partition_info;
151
+ partition_info.enabled = true ;
152
+ partition_info.min_block_size = 3 ;
153
+ std::unordered_map<torch::jit::Node*, int > fallback_nodes;
154
+ std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
155
+ torch_tensorrt::core::partitioning::segment_graph (g->block (), partition_info, fallback_nodes);
156
+ ASSERT_TRUE (
157
+ checkSegmentedBlockNumber (segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT , 1 ));
158
+ ASSERT_TRUE (
159
+ checkSegmentedBlockNumber (segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch , 1 ));
160
+ ASSERT_TRUE (checkSegmentedBlockNodesMapping (segmented_blocks, g, {{0 , 1 , 2 , 3 }, {4 , 5 , 6 , 7 }}));
161
+ }
162
+
123
163
TEST (Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
124
164
const auto graph = R"IR(
125
165
graph(%0 : Tensor,
0 commit comments