Skip to content

Commit 65feab1

Browse files
fix: Support non -1 end idx and <0 start idx in aten::flatten converter (#2321)
1 parent ff4d940 commit 65feab1

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

core/conversion/converters/impl/shuffle.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
2020
auto in_shape = util::toVec(in->getDimensions());
2121
std::vector<int64_t> out_shape;
2222
if (ctx->input_is_dynamic) {
23-
end_dim = (end_dim == -1) ? in_shape.size() - 1 : end_dim;
23+
if (start_dim < 0) {
24+
start_dim = start_dim + in_shape.size();
25+
}
26+
if (end_dim < 0) {
27+
end_dim = end_dim + in_shape.size();
28+
}
2429
int nbDynamicFlattenedDims = 0;
2530
int nbDynamicUnflattenedDims = 0;
2631
for (int i = 0; i < (int)in_shape.size(); i++) {

tests/core/conversion/converters/test_shuffle.cpp

+47-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "tests/util/util.h"
55
#include "torch/csrc/jit/ir/irparser.h"
66

7-
// TODO: IR Parser doesnt work well with neg numbers
87
TEST(Converters, ATenFlattenConvertsCorrectly) {
98
const auto graph = R"IR(
109
graph(%0 : Tensor):
@@ -23,12 +22,32 @@ TEST(Converters, ATenFlattenConvertsCorrectly) {
2322
in = at::clone(in);
2423
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
2524
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
26-
auto trt = trt_results[0].reshape_as(jit_results[0]);
2725

28-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
26+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
27+
}
28+
29+
TEST(Converters, ATenFlattenNegDimsConvertsCorrectly) {
30+
const auto graph = R"IR(
31+
graph(%0 : Tensor):
32+
%1 : int = prim::Constant[value=-3]()
33+
%2 : int = prim::Constant[value=-2]()
34+
%3 : Tensor = aten::flatten(%0, %1, %2)
35+
return (%3))IR";
36+
37+
auto g = std::make_shared<torch::jit::Graph>();
38+
torch::jit::parseIR(graph, g.get());
39+
40+
auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
41+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
42+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
43+
44+
in = at::clone(in);
45+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
46+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
47+
48+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
2949
}
3050

31-
// TODO: IR Parser doesnt work well with neg numbers
3251
TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
3352
const auto graph = R"IR(
3453
graph(%0 : Tensor):
@@ -47,9 +66,8 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
4766
in = at::clone(in);
4867
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
4968
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
50-
auto trt = trt_results[0].reshape_as(jit_results[0]);
5169

52-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
70+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
5371
}
5472

5573
TEST(Converters, ATenReshapeConvertsCorrectly) {
@@ -215,6 +233,29 @@ TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicBatch) {
215233
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
216234
}
217235

236+
TEST(Converters, ATenFlattenNegDimsConvertsCorrectlyWithDynamicBatch) {
237+
const auto graph = R"IR(
238+
graph(%0 : Tensor):
239+
%1 : int = prim::Constant[value=-3]()
240+
%2 : int = prim::Constant[value=-2]()
241+
%3 : Tensor = aten::flatten(%0, %1, %2)
242+
return (%3))IR";
243+
244+
auto g = std::make_shared<torch::jit::Graph>();
245+
torch::jit::parseIR(graph, g.get());
246+
247+
auto in = at::randint(0, 5, {2, 3, 4}, {at::kCUDA});
248+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
249+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
250+
251+
in = at::clone(in);
252+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
253+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
254+
auto trt = trt_results[0].reshape_as(jit_results[0]);
255+
256+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
257+
}
258+
218259
TEST(Converters, ATenTransposeConvertsCorrectly) {
219260
const auto graph = R"IR(
220261
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)