
Commit c60070b

feat: support type promotion in aten::cat converter (#1911)
1 parent 520f62a · commit c60070b

File tree

3 files changed: +93 -0 lines changed

core/conversion/converters/converter_util.h (+2)

@@ -96,6 +96,8 @@ nvinfer1::ITensor* get_slice_size(

 nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s);

+nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
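
Note that this commit only declares promote_types in the header; its definition lives elsewhere in converter_util and is not part of the hunks shown here. Purely as an illustrative sketch (not the actual implementation), a promotion rule over nvinfer1::DataType can be expressed as a ranking that roughly mirrors PyTorch's type-promotion order, where kBOOL < kINT8 < kINT32 < kHALF < kFLOAT:

// Illustrative sketch only -- the real promote_types is defined in converter_util.cpp.
// Assumes a simple PyTorch-style ordering: bool < int8 < int32 < half < float.
#include <NvInfer.h>

static int promotion_rank(nvinfer1::DataType t) {
  switch (t) {
    case nvinfer1::DataType::kBOOL:
      return 0;
    case nvinfer1::DataType::kINT8:
      return 1;
    case nvinfer1::DataType::kINT32:
      return 2;
    case nvinfer1::DataType::kHALF:
      return 3;
    case nvinfer1::DataType::kFLOAT:
    default:
      return 4;
  }
}

nvinfer1::DataType promote_types_sketch(nvinfer1::DataType type_a, nvinfer1::DataType type_b) {
  // Return whichever of the two types sits higher in the promotion order.
  return promotion_rank(type_a) >= promotion_rank(type_b) ? type_a : type_b;
}

Under this ordering, for example, promoting kINT32 with kHALF yields kHALF, and promoting anything with kFLOAT yields kFLOAT, which matches the dtype combinations exercised by the new tests below.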

core/conversion/converters/impl/concat.cpp (+12)

@@ -1,3 +1,4 @@
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/util/prelude.h"
@@ -27,6 +28,17 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
   }
 }

+auto promo_dtype = tensors[0]->getType();
+for (size_t idx = 1UL; idx < tensors.size(); ++idx) {
+  promo_dtype = promote_types(promo_dtype, tensors[idx]->getType());
+}
+
+for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
+  if (tensors[idx]->getType() != promo_dtype) {
+    tensors[idx] = castITensor(ctx, tensors[idx], promo_dtype, util::node_info(n) + "_cast_" + std::to_string(idx));
+  }
+}
+
 if (dim < 0) {
   dim = tensors[0]->getDimensions().nbDims + dim;
 }
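
The two added loops first fold promote_types over the dtypes of all inputs to find the promoted type, then cast (via castITensor) every input whose type differs from it, so the TensorRT concatenation layer only ever sees uniformly typed inputs. This mirrors the eager-mode behavior of aten::cat, which already promotes mixed input dtypes. A minimal standalone libtorch check of that eager behavior (not part of this repo, shown only for context) would be:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Mixed-dtype concatenation: eager PyTorch promotes Int to Float,
  // which is the behavior the converter now reproduces in TensorRT.
  auto a = at::randint(1, 10, {5}).to(at::kFloat);
  auto b = at::randint(1, 10, {5}).to(at::kInt);
  auto out = at::cat({a, b}, 0);
  std::cout << out.scalar_type() << std::endl; // prints "Float"
  return 0;
}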

tests/core/conversion/converters/test_concat.cpp (+79)

@@ -29,6 +29,85 @@ TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }

+TEST(Converters, ATenCatFloatIntConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1)
+        %3 : int = prim::Constant[value=0]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat);
+  auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenCatIntHalfIntHalfConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor,
+            %2 : Tensor,
+            %3 : Tensor):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1, %2, %3)
+        %3 : int = prim::Constant[value=0]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+  auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
+  auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+  auto in4 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3, in4});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3, in4}, nvinfer1::DataType::kHALF);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenCatHalfIntFloatConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor,
+            %2 : Tensor):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1, %2)
+        %3 : int = prim::Constant[value=0]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+  auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
+  auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,
