@@ -49,5 +49,52 @@ TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
49
49
params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
50
50
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
51
51
52
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
53
+ }
54
+ TEST (Converters, ATenCatPureTensorNegDimConvertsCorrectly) {
55
+ const auto graph = R"IR(
56
+ graph(%0 : Tensor,
57
+ %1 : Tensor):
58
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
59
+ %3 : int = prim::Constant[value=-1]()
60
+ %4 : Tensor = aten::cat(%2, %3)
61
+ return (%4))IR" ;
62
+
63
+ auto g = std::make_shared<torch::jit::Graph>();
64
+ torch::jit::parseIR (graph, g.get ());
65
+
66
+ auto in1 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
67
+ auto in2 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
68
+
69
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
70
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1, in2});
71
+
72
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
73
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1, in2});
74
+
75
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
76
+ }
77
+
78
+ TEST (Converters, ATenCatDiffTensorNegDimConvertsCorrectly) {
79
+ const auto graph = R"IR(
80
+ graph(%0 : Tensor,
81
+ %1 : Float(5)):
82
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
83
+ %3 : int = prim::Constant[value=-1]()
84
+ %4 : Tensor = aten::cat(%2, %3)
85
+ return (%4))IR" ;
86
+
87
+ auto g = std::make_shared<torch::jit::Graph>();
88
+ torch::jit::parseIR (graph, g.get ());
89
+
90
+ auto in1 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
91
+ auto in2 = at::randint (1 , 10 , {5 , 5 }, {at::kCUDA });
92
+
93
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
94
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in1});
95
+
96
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {in2});
97
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in1});
98
+
52
99
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
53
100
}
0 commit comments