@@ -40,6 +40,41 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
40
40
TestATenStackPureTensorConvertsCorrectly (graph2);
41
41
}
42
42
43
+ TEST (Converters, ATenStackPureTensorDynamicConvertsCorrectly) {
44
+ auto TestATenStackPureTensorConvertsCorrectly = [](const std::string& graph) {
45
+ auto g = std::make_shared<torch::jit::Graph>();
46
+ torch::jit::parseIR (graph, g.get ());
47
+
48
+ auto in1 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
49
+ auto in2 = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
50
+
51
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
52
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
53
+
54
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
55
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic (g, params, {in1, in2});
56
+
57
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], THRESHOLD_E5));
58
+ };
59
+ const auto graph = R"IR(
60
+ graph(%0 : Tensor,
61
+ %1 : Tensor):
62
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
63
+ %3 : int = prim::Constant[value=1]()
64
+ %4 : Tensor = aten::stack(%2, %3)
65
+ return (%4))IR" ;
66
+ const auto graph2 = R"IR(
67
+ graph(%0 : Tensor,
68
+ %1 : Tensor):
69
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
70
+ %3 : int = prim::Constant[value=-1]()
71
+ %4 : Tensor = aten::stack(%2, %3)
72
+ return (%4))IR" ;
73
+
74
+ TestATenStackPureTensorConvertsCorrectly (graph);
75
+ TestATenStackPureTensorConvertsCorrectly (graph2);
76
+ }
77
+
43
78
TEST (Converters, ATenStackDiffTensorConvertsCorrectly) {
44
79
auto TestATenStackDiffTensorConvertsCorrectly = [](const std::string& graph) {
45
80
auto g = std::make_shared<torch::jit::Graph>();
0 commit comments