@@ -18,8 +18,7 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
18
18
params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
19
19
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
20
20
21
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
22
- jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), THRESHOLD_E5));
21
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], THRESHOLD_E5));
23
22
};
24
23
const auto graph = R"IR(
25
24
graph(%0 : Tensor,
@@ -35,9 +34,17 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
35
34
%3 : int = prim::Constant[value=-1]()
36
35
%4 : Tensor = aten::stack(%2, %3)
37
36
return (%4))IR" ;
37
+ const auto graph3 = R"IR(
38
+ graph(%0 : Tensor,
39
+ %1 : Tensor):
40
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
41
+ %3 : int = prim::Constant[value=-2]()
42
+ %4 : Tensor = aten::stack(%2, %3)
43
+ return (%4))IR" ;
38
44
39
45
TestATenStackPureTensorConvertsCorrectly (graph);
40
46
TestATenStackPureTensorConvertsCorrectly (graph2);
47
+ TestATenStackPureTensorConvertsCorrectly (graph3);
41
48
}
42
49
43
50
TEST (Converters, ATenStackPureTensorDynamicConvertsCorrectly) {
@@ -89,8 +96,7 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
89
96
params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {in2});
90
97
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1});
91
98
92
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
93
- jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), THRESHOLD_E5));
99
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], THRESHOLD_E5));
94
100
};
95
101
const auto graph = R"IR(
96
102
graph(%0 : Tensor,
@@ -106,6 +112,14 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
106
112
%3 : int = prim::Constant[value=-1]()
107
113
%4 : Tensor = aten::stack(%2, %3)
108
114
return (%4))IR" ;
115
+ const auto graph3 = R"IR(
116
+ graph(%0 : Tensor,
117
+ %1 : Float(4, 4, 4, strides=[16, 4, 1])):
118
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
119
+ %3 : int = prim::Constant[value=-3]()
120
+ %4 : Tensor = aten::stack(%2, %3)
121
+ return (%4))IR" ;
109
122
TestATenStackDiffTensorConvertsCorrectly (graph);
110
123
TestATenStackDiffTensorConvertsCorrectly (graph2);
124
+ TestATenStackDiffTensorConvertsCorrectly (graph3);
111
125
}
0 commit comments