@@ -6,26 +6,26 @@
 
 // TODO: IR Parser doesnt work well with neg numbers
 TEST(Converters, ATenFlattenConvertsCorrectly) {
-  const auto graph = R"IR(
-    graph(%0 : Tensor):
-      %1 : int = prim::Constant[value=0]()
-      %2 : int = prim::Constant[value=1]()
-      %3 : Tensor = aten::flatten(%0, %1, %2)
-      return (%3))IR";
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, &*g);
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=1]()
+      %3 : Tensor = aten::flatten(%0, %1, %2)
+      return (%3))IR";
 
-  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
-  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
 
-  in = at::clone(in);
-  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-  auto trt = trt_results[0].reshape_as(jit_results[0]);
+  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
 
-  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
 // TODO: IR Parser doesnt work well with neg numbers
@@ -164,4 +164,27 @@ TEST(Converters, ATenPermute5DConvertsCorrectly) {
   auto trt = trt_results[0].reshape_as(jit_results[0]);
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-}
+}
+
+TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=1]()
+      %3 : Tensor = aten::flatten(%0, %1, %2)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
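
Note: both flatten tests pin start_dim/end_dim to the literals 0 and 1 because of the TODO in this file about the IR parser and negative numbers. As a hedged sketch only (not part of this diff, with a hypothetical test name), a negative-dimension variant built from the same helpers might look like the following once the parser accepts negative constants:

// Hypothetical follow-up (not in this commit): exercise aten::flatten with a
// negative end_dim once the IR parser handles negative constants.
TEST(Converters, ATenFlattenNegativeDimsConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=0]()
      %2 : int = prim::Constant[value=-1]()
      %3 : Tensor = aten::flatten(%0, %1, %2)
      return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  // Same flow as the existing tests: run the graph through TorchScript and
  // through the converted TensorRT engine, then compare the outputs.
  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
  auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}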