Commit c7d6b49

feat(aten::permute): Implement permute support
Signed-off-by: Naren Dasan <[email protected]>
1 parent 461e2ca commit c7d6b49

File tree: 2 files changed, +115 -28 lines changed

core/conversion/converters/impl/shuffle.cpp
tests/core/converters/test_shuffle.cpp

Diff for: core/conversion/converters/impl/shuffle.cpp

+21 lines changed

@@ -59,6 +59,27 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
             auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
             LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
 
+            return true;
+        }
+    }).pattern({
+        "aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))",
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            auto in = args[0].ITensor();
+            auto in_shape = util::toVec(in->getDimensions());
+            auto new_order = args[1].unwrapToIntList().vec();
+
+            LOG_DEBUG("Shuffle to: " << util::toDims(new_order));
+
+            auto shuffle = ctx->net->addShuffle(*in);
+            TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+            nvinfer1::Permutation permute;
+            std::copy(new_order.begin(), new_order.end(), permute.order);
+            shuffle->setSecondTranspose(permute);
+            shuffle->setName(util::node_info(n).c_str());
+
+            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
             return true;
         }
     });
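Note on the converter above: it copies the requested dimension order into an nvinfer1::Permutation and applies it with setSecondTranspose, so output dimension i of the shuffle layer takes input dimension new_order[i]. Below is a minimal standalone sketch of that index mapping (plain C++, no TensorRT required; permute_shape is a hypothetical helper, not part of this commit):

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: mirrors the shuffle layer's second transpose,
// where out_shape[i] = in_shape[order[i]] (the same semantics as aten::permute).
std::vector<int64_t> permute_shape(const std::vector<int64_t>& in_shape,
                                   const std::vector<int64_t>& order) {
    std::vector<int64_t> out_shape(order.size());
    for (size_t i = 0; i < order.size(); ++i) {
        out_shape[i] = in_shape[order[i]];
    }
    return out_shape;
}

int main() {
    // Same order and input shape as the ATenPermuteConvertsCorrectly test below.
    for (auto d : permute_shape({2, 3, 2, 3}, {3, 0, 1, 2})) {
        std::cout << d << " ";  // prints: 3 2 3 2
    }
    std::cout << std::endl;
    return 0;
}

With the 4D test order [3, 0, 1, 2] and input shape [2, 3, 2, 3] this yields the permuted shape [3, 2, 3, 2].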

Diff for: tests/core/converters/test_shuffle.cpp

+94 -28 lines changed

@@ -30,17 +30,87 @@ TEST(Converters, ATenFlattenConvertsCorrectly) {
 
 // TODO: IR Parser doesnt work well with neg numbers
 TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
-    const auto graph = R"IR(
-      graph(%0 : Tensor):
-            %1 : int = prim::Constant[value=1]()
-            %2 : int = prim::Constant[value=2]()
-            %3 : Tensor = aten::flatten(%0, %1, %2)
-            return (%3))IR";
+    const auto graph = R"IR(
+      graph(%0 : Tensor):
+            %1 : int = prim::Constant[value=1]()
+            %2 : int = prim::Constant[value=2]()
+            %3 : Tensor = aten::flatten(%0, %1, %2)
+            return (%3))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, &*g);
+
+    auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+    in = at::clone(in);
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+    auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
 
-    auto g = std::make_shared<torch::jit::Graph>();
+TEST(Converters, ATenReshapeConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%0 : Tensor):
+            %1 : int = prim::Constant[value=3]()
+            %2 : int = prim::Constant[value=2]()
+            %3 : int[] = prim::ListConstruct(%1, %2)
+            %4 : Tensor = aten::reshape(%0, %3)
+            return (%4))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, &*g);
+
+    auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+    in = at::clone(in);
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+    auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenViewConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%0 : Tensor):
+            %1 : int = prim::Constant[value=3]()
+            %2 : int = prim::Constant[value=2]()
+            %3 : int[] = prim::ListConstruct(%1, %2)
+            %4 : Tensor = aten::view(%0, %3)
+            return (%4))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
+    torch::jit::parseIR(graph, &*g);
+
+    auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+    in = at::clone(in);
+    params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+    auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenPermuteConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %2 : int[] = prim::Constant[value=[3, 0, 1, 2]]()
+            %3 : Tensor = aten::permute(%x.1, %2)
+            return (%3))IR";
+
+    auto g = std::make_shared<torch::jit::Graph>();
     torch::jit::parseIR(graph, &*g);
 
-    auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
+    auto in = at::randint(0, 5, {2, 3, 2, 3}, {at::kCUDA});
     auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
     auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
 
@@ -52,19 +122,17 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
     ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
-TEST(Converters, ATenReshapeConvertsCorrectly) {
-    const auto graph = R"IR(
-      graph(%0 : Tensor):
-            %1 : int = prim::Constant[value=3]()
-            %2 : int = prim::Constant[value=2]()
-            %3 : int[] = prim::ListConstruct(%1, %2)
-            %4 : Tensor = aten::reshape(%0, %3)
-            return (%4))IR";
+TEST(Converters, ATenPermute3DConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %2 : int[] = prim::Constant[value=[0, 2, 1]]()
+            %3 : Tensor = aten::permute(%x.1, %2)
+            return (%3))IR";
 
-    auto g = std::make_shared<torch::jit::Graph>();
+    auto g = std::make_shared<torch::jit::Graph>();
     torch::jit::parseIR(graph, &*g);
 
-    auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+    auto in = at::randint(0, 5, {2, 2, 3}, {at::kCUDA});
     auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
     auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
 
@@ -76,19 +144,17 @@ TEST(Converters, ATenReshapeConvertsCorrectly) {
     ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
-TEST(Converters, ATenViewConvertsCorrectly) {
-    const auto graph = R"IR(
-      graph(%0 : Tensor):
-            %1 : int = prim::Constant[value=3]()
-            %2 : int = prim::Constant[value=2]()
-            %3 : int[] = prim::ListConstruct(%1, %2)
-            %4 : Tensor = aten::view(%0, %3)
-            return (%4))IR";
+TEST(Converters, ATenPermute5DConvertsCorrectly) {
+    const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+            %2 : int[] = prim::Constant[value=[3, 4, 0, 2, 1]]()
+            %3 : Tensor = aten::permute(%x.1, %2)
+            return (%3))IR";
 
-    auto g = std::make_shared<torch::jit::Graph>();
+    auto g = std::make_shared<torch::jit::Graph>();
     torch::jit::parseIR(graph, &*g);
 
-    auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
+    auto in = at::randint(0, 5, {2, 2, 1, 2, 3}, {at::kCUDA});
     auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
     auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
 
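The new tests cover 3D, 4D, and 5D permutations and compare the JIT result (RunGraph) against the TensorRT engine result (RunGraphEngine) with almostEqual at a 2e-6 threshold. For a quick reference check outside the test harness, the same permutation can be reproduced with libtorch's at::Tensor::permute; a small sketch, run on CPU so it does not require CUDA (not part of this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
    // Same input shape and dimension order as the ATenPermuteConvertsCorrectly test,
    // but allocated on CPU so the snippet runs without a GPU.
    auto in = torch::randint(0, 5, {2, 3, 2, 3});
    auto ref = in.permute({3, 0, 1, 2});
    std::cout << ref.sizes() << std::endl;  // [3, 2, 3, 2]
    return 0;
}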