Skip to content

Commit 4eb20bb

Browse files
committed
fix(aten::flatten): Fixes dynamic shape for flatten
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 8d5b123 commit 4eb20bb

File tree

2 files changed

+53
-20
lines changed

2 files changed

+53
-20
lines changed

Diff for: core/conversion/converters/impl/shuffle.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
1717
auto start_dim = args[1].unwrapToInt();
1818
auto end_dim = args[2].unwrapToInt();
1919
auto in_shape = util::toVec(in->getDimensions());
20-
auto out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes();
20+
std::vector<int64_t> out_shape;
21+
if (ctx->input_is_dynamic) {
22+
out_shape = std::vector<int64_t>({in_shape[0], -1});
23+
} else {
24+
out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
25+
}
2126

2227
auto shuffle = ctx->net->addShuffle(*in);
2328
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
@@ -33,7 +38,12 @@ static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
3338
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
3439
auto in = args[0].ITensor();
3540
auto in_shape = util::toVec(in->getDimensions());
36-
auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes();
41+
std::vector<int64_t> new_shape;
42+
if (ctx->input_is_dynamic) {
43+
TRTORCH_THROW_ERROR("Resize is currently not supported in dynamic input shape compilation");
44+
} else {
45+
new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
46+
}
3747

3848
auto shuffle = ctx->net->addShuffle(*in);
3949
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);

Diff for: tests/core/converters/test_shuffle.cpp

+41-18
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@
66

77
// TODO: IR Parser doesn't work well with negative numbers
88
TEST(Converters, ATenFlattenConvertsCorrectly) {
9-
const auto graph = R"IR(
10-
graph(%0 : Tensor):
11-
%1 : int = prim::Constant[value=0]()
12-
%2 : int = prim::Constant[value=1]()
13-
%3 : Tensor = aten::flatten(%0, %1, %2)
14-
return (%3))IR";
15-
16-
auto g = std::make_shared<torch::jit::Graph>();
17-
torch::jit::parseIR(graph, &*g);
9+
const auto graph = R"IR(
10+
graph(%0 : Tensor):
11+
%1 : int = prim::Constant[value=0]()
12+
%2 : int = prim::Constant[value=1]()
13+
%3 : Tensor = aten::flatten(%0, %1, %2)
14+
return (%3))IR";
1815

19-
auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
20-
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
21-
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
16+
auto g = std::make_shared<torch::jit::Graph>();
17+
torch::jit::parseIR(graph, &*g);
2218

23-
in = at::clone(in);
24-
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
25-
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
26-
auto trt = trt_results[0].reshape_as(jit_results[0]);
19+
auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
20+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
21+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
2722

28-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
23+
in = at::clone(in);
24+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
25+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
26+
auto trt = trt_results[0].reshape_as(jit_results[0]);
27+
28+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
2929
}
3030

3131
// TODO: IR Parser doesn't work well with negative numbers
@@ -164,4 +164,27 @@ TEST(Converters, ATenPermute5DConvertsCorrectly) {
164164
auto trt = trt_results[0].reshape_as(jit_results[0]);
165165

166166
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
167-
}
167+
}
168+
169+
TEST(Converters, ATenFlattenConvertsCorrectlyWithDynamicInput) {
170+
const auto graph = R"IR(
171+
graph(%0 : Tensor):
172+
%1 : int = prim::Constant[value=0]()
173+
%2 : int = prim::Constant[value=1]()
174+
%3 : Tensor = aten::flatten(%0, %1, %2)
175+
return (%3))IR";
176+
177+
auto g = std::make_shared<torch::jit::Graph>();
178+
torch::jit::parseIR(graph, &*g);
179+
180+
auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
181+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
182+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
183+
184+
in = at::clone(in);
185+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
186+
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in});
187+
auto trt = trt_results[0].reshape_as(jit_results[0]);
188+
189+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
190+
}

0 commit comments

Comments
 (0)