Skip to content

Commit 09e15b2

Browse files
authored
Merge pull request #1804 from mfeliz-cruise/michael.feliz/dynamic_aten_stack
[fix] aten::stack with dynamic inputs
2 parents 34a3328 + 285913d commit 09e15b2

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

core/conversion/converters/impl/stack.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,9 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt
4343
auto cont = t.toCustomClass<TensorContainer>();
4444
itensor = cont->tensor();
4545
}
46-
4746
auto shuffle_layer = ctx->net->addShuffle(*itensor);
4847
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
49-
shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim));
48+
shuffle_layer->setReshapeDimensions(util::unsqueezeDims(itensor->getDimensions(), dim, 1, false));
5049

5150
tensors.push_back(shuffle_layer->getOutput(0));
5251
}

tests/core/conversion/converters/test_stack.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,41 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
4040
TestATenStackPureTensorConvertsCorrectly(graph2);
4141
}
4242

43+
TEST(Converters, ATenStackPureTensorDynamicConvertsCorrectly) {
44+
auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) {
45+
auto g = std::make_shared<torch::jit::Graph>();
46+
torch::jit::parseIR(graph, g.get());
47+
48+
auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
49+
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
50+
51+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
52+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
53+
54+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
55+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in1, in2});
56+
57+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5));
58+
};
59+
const auto graph = R"IR(
60+
graph(%0 : Tensor,
61+
%1 : Tensor):
62+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
63+
%3 : int = prim::Constant[value=1]()
64+
%4 : Tensor = aten::stack(%2, %3)
65+
return (%4))IR";
66+
const auto graph2 = R"IR(
67+
graph(%0 : Tensor,
68+
%1 : Tensor):
69+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
70+
%3 : int = prim::Constant[value=-1]()
71+
%4 : Tensor = aten::stack(%2, %3)
72+
return (%4))IR";
73+
74+
TestATenStackPureTensorConvertsCorrectly(graph);
75+
TestATenStackPureTensorConvertsCorrectly(graph2);
76+
}
77+
4378
TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
4479
auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) {
4580
auto g = std::make_shared<torch::jit::Graph>();

0 commit comments

Comments
 (0)