Skip to content

Commit 847ebff

Browse files
mfeliz-cruisenarendasan
authored andcommitted
Support dims < -1 in aten::stack converter (#150)
* Handle dim < -1 in aten::stack * Remove reshapes
1 parent 2e58a9e commit 847ebff

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

core/conversion/converters/impl/stack.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt
1919
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2020
auto in = args[0].IValue()->toListRef();
2121
auto dim = args[1].unwrapToInt();
22-
if (-1 == dim) {
22+
if (dim < 0) {
2323
auto first_in = in[0];
2424
if (first_in.isTensor()) {
25-
dim = first_in.toTensor().ndimension();
25+
dim = first_in.toTensor().ndimension() + dim + 1;
2626
} else {
27-
dim = first_in.toCustomClass<TensorContainer>()->tensor()->getDimensions().nbDims;
27+
dim = first_in.toCustomClass<TensorContainer>()->tensor()->getDimensions().nbDims + dim + 1;
2828
}
2929
}
3030

tests/core/conversion/converters/test_stack.cpp

+18-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
1818
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
1919
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
2020

21-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
22-
jit_results[0], trt_results[0].reshape_as(jit_results[0]), THRESHOLD_E5));
21+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5));
2322
};
2423
const auto graph = R"IR(
2524
graph(%0 : Tensor,
@@ -35,9 +34,17 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
3534
%3 : int = prim::Constant[value=-1]()
3635
%4 : Tensor = aten::stack(%2, %3)
3736
return (%4))IR";
37+
const auto graph3 = R"IR(
38+
graph(%0 : Tensor,
39+
%1 : Tensor):
40+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
41+
%3 : int = prim::Constant[value=-2]()
42+
%4 : Tensor = aten::stack(%2, %3)
43+
return (%4))IR";
3844

3945
TestATenStackPureTensorConvertsCorrectly(graph);
4046
TestATenStackPureTensorConvertsCorrectly(graph2);
47+
TestATenStackPureTensorConvertsCorrectly(graph3);
4148
}
4249

4350
TEST(Converters, ATenStackPureTensorDynamicConvertsCorrectly) {
@@ -89,8 +96,7 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
8996
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {in2});
9097
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});
9198

92-
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(
93-
jit_results[0], trt_results[0].reshape_as(jit_results[0]), THRESHOLD_E5));
99+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], THRESHOLD_E5));
94100
};
95101
const auto graph = R"IR(
96102
graph(%0 : Tensor,
@@ -106,6 +112,14 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
106112
%3 : int = prim::Constant[value=-1]()
107113
%4 : Tensor = aten::stack(%2, %3)
108114
return (%4))IR";
115+
const auto graph3 = R"IR(
116+
graph(%0 : Tensor,
117+
%1 : Float(4, 4, 4, strides=[16, 4, 1])):
118+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
119+
%3 : int = prim::Constant[value=-3]()
120+
%4 : Tensor = aten::stack(%2, %3)
121+
return (%4))IR";
109122
TestATenStackDiffTensorConvertsCorrectly(graph);
110123
TestATenStackDiffTensorConvertsCorrectly(graph2);
124+
TestATenStackDiffTensorConvertsCorrectly(graph3);
111125
}

0 commit comments

Comments
 (0)