@@ -29,6 +29,85 @@ TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
29
29
torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
30
30
}
31
31
32
+ TEST (Converters, ATenCatFloatIntConvertsCorrectly) {
33
+ const auto graph = R"IR(
34
+ graph(%0 : Tensor,
35
+ %1 : Tensor):
36
+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
37
+ %3 : int = prim::Constant[value=0]()
38
+ %4 : Tensor = aten::cat(%2, %3)
39
+ return (%4))IR" ;
40
+
41
+ auto g = std::make_shared<torch::jit::Graph>();
42
+ torch::jit::parseIR (graph, g.get ());
43
+
44
+ auto in1 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kFloat );
45
+ auto in2 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kInt );
46
+
47
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
48
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
49
+
50
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
51
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
52
+
53
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
54
+ }
55
+
56
+ TEST (Converters, ATenCatIntHalfIntHalfConvertsCorrectly) {
57
+ const auto graph = R"IR(
58
+ graph(%0 : Tensor,
59
+ %1 : Tensor,
60
+ %2 : Tensor,
61
+ %3 : Tensor):
62
+ %2 : Tensor[] = prim::ListConstruct(%0, %1, %2, %3)
63
+ %3 : int = prim::Constant[value=0]()
64
+ %4 : Tensor = aten::cat(%2, %3)
65
+ return (%4))IR" ;
66
+
67
+ auto g = std::make_shared<torch::jit::Graph>();
68
+ torch::jit::parseIR (graph, g.get ());
69
+
70
+ auto in1 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kInt );
71
+ auto in2 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kHalf );
72
+ auto in3 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kInt );
73
+ auto in4 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kHalf );
74
+
75
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
76
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2, in3, in4});
77
+
78
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
79
+ auto trt_results =
80
+ torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2, in3, in4}, nvinfer1::DataType::kHALF );
81
+
82
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
83
+ }
84
+
85
+ TEST (Converters, ATenCatHalfIntFloatConvertsCorrectly) {
86
+ const auto graph = R"IR(
87
+ graph(%0 : Tensor,
88
+ %1 : Tensor,
89
+ %2 : Tensor):
90
+ %2 : Tensor[] = prim::ListConstruct(%0, %1, %2)
91
+ %3 : int = prim::Constant[value=0]()
92
+ %4 : Tensor = aten::cat(%2, %3)
93
+ return (%4))IR" ;
94
+
95
+ auto g = std::make_shared<torch::jit::Graph>();
96
+ torch::jit::parseIR (graph, g.get ());
97
+
98
+ auto in1 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kInt );
99
+ auto in2 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kHalf );
100
+ auto in3 = at::randint (1 , 10 , {5 }, {at::kCUDA }).to (at::kFloat );
101
+
102
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
103
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2, in3});
104
+
105
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
106
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2, in3});
107
+
108
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
109
+ }
110
+
32
111
TEST (Converters, ATenCatDiffTensorConvertsCorrectly) {
33
112
const auto graph = R"IR(
34
113
graph(%0 : Tensor,
0 commit comments