4
4
#include " tests/util/util.h"
5
5
#include " torch/csrc/jit/ir/irparser.h"
6
6
7
- // TODO: IR Parser doesnt work well with neg numbers
8
7
TEST (Converters, ATenFlattenConvertsCorrectly) {
9
8
const auto graph = R"IR(
10
9
graph(%0 : Tensor):
@@ -23,12 +22,32 @@ TEST(Converters, ATenFlattenConvertsCorrectly) {
23
22
in = at::clone (in);
24
23
params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
25
24
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
26
- auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
27
25
28
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
26
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
27
+ }
28
+
29
+ TEST (Converters, ATenFlattenNegDimsConvertsCorrectly) {
30
+ const auto graph = R"IR(
31
+ graph(%0 : Tensor):
32
+ %1 : int = prim::Constant[value=-3]()
33
+ %2 : int = prim::Constant[value=-2]()
34
+ %3 : Tensor = aten::flatten(%0, %1, %2)
35
+ return (%3))IR" ;
36
+
37
+ auto g = std::make_shared<torch::jit::Graph>();
38
+ torch::jit::parseIR (graph, g.get ());
39
+
40
+ auto in = at::randint (0 , 5 , {2 , 3 , 3 }, {at::kCUDA });
41
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
42
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
43
+
44
+ in = at::clone (in);
45
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
46
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
47
+
48
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
29
49
}
30
50
31
- // TODO: IR Parser doesnt work well with neg numbers
32
51
TEST (Converters, ATenFlattenOtherDimsConvertsCorrectly) {
33
52
const auto graph = R"IR(
34
53
graph(%0 : Tensor):
@@ -47,9 +66,8 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
47
66
in = at::clone (in);
48
67
params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
49
68
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
50
- auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
51
69
52
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt , 2e-6 ));
70
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[ 0 ] , 2e-6 ));
53
71
}
54
72
55
73
TEST (Converters, ATenReshapeConvertsCorrectly) {
@@ -215,6 +233,29 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) {
215
233
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
216
234
}
217
235
236
+ TEST (Converters, ATenFlattenNegDimsConvertsCorrectlyWithDynamicBatch) {
237
+ const auto graph = R"IR(
238
+ graph(%0 : Tensor):
239
+ %1 : int = prim::Constant[value=-3]()
240
+ %2 : int = prim::Constant[value=-2]()
241
+ %3 : Tensor = aten::flatten(%0, %1, %2)
242
+ return (%3))IR" ;
243
+
244
+ auto g = std::make_shared<torch::jit::Graph>();
245
+ torch::jit::parseIR (graph, g.get ());
246
+
247
+ auto in = at::randint (0 , 5 , {2 , 3 , 4 }, {at::kCUDA });
248
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
249
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
250
+
251
+ in = at::clone (in);
252
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
253
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in}, true );
254
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
255
+
256
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
257
+ }
258
+
218
259
TEST (Converters, ATenTransposeConvertsCorrectly) {
219
260
const auto graph = R"IR(
220
261
graph(%x.1 : Tensor):
0 commit comments