Skip to content

Commit d8ca182

Browse files
committed
fix(aten::cat): support neg dim for cat
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 114969b commit d8ca182

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

Diff for: core/conversion/converters/impl/concat.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@ auto cat_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
2020
for (auto t : ts) {
2121
if (t.isTensor()) {
2222
auto torch_tensor = t.toTensor();
23-
auto t_weights = Weights(ctx, torch_tensor);
24-
auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
25-
tensors.push_back(const_layer->getOutput(0));
23+
tensors.push_back(tensor_to_const(ctx, torch_tensor));
2624
} else {
2725
auto cont = t.toCustomClass<TensorContainer>();
2826
tensors.push_back(cont->tensor());
2927
}
3028
}
3129

30+
if (dim < 0) {
31+
dim = tensors[0]->getDimensions().nbDims + dim;
32+
}
33+
3234
auto cat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
3335
cat_layer->setAxis(static_cast<int>(dim));
3436
auto cat_out = ctx->AssociateValueAndTensor(n->outputs()[0], cat_layer->getOutput(0));

Diff for: tests/core/conversion/converters/test_concat.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,52 @@ TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
4949
params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
5050
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
5151

52+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
53+
}
54+
TEST(Converters, ATenCatPureTensorNegDimConvertsCorrectly) {
55+
const auto graph = R"IR(
56+
graph(%0 : Tensor,
57+
%1 : Tensor):
58+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
59+
%3 : int = prim::Constant[value=-1]()
60+
%4 : Tensor = aten::cat(%2, %3)
61+
return (%4))IR";
62+
63+
auto g = std::make_shared<torch::jit::Graph>();
64+
torch::jit::parseIR(graph, g.get());
65+
66+
auto in1 = at::randint(1, 10, {5, 5}, {at::kCUDA});
67+
auto in2 = at::randint(1, 10, {5, 5}, {at::kCUDA});
68+
69+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
70+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
71+
72+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
73+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
74+
75+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
76+
}
77+
78+
TEST(Converters, ATenCatDiffTensorNegDimConvertsCorrectly) {
79+
const auto graph = R"IR(
80+
graph(%0 : Tensor,
81+
%1 : Float(5)):
82+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
83+
%3 : int = prim::Constant[value=-1]()
84+
%4 : Tensor = aten::cat(%2, %3)
85+
return (%4))IR";
86+
87+
auto g = std::make_shared<torch::jit::Graph>();
88+
torch::jit::parseIR(graph, g.get());
89+
90+
auto in1 = at::randint(1, 10, {5, 5}, {at::kCUDA});
91+
auto in2 = at::randint(1, 10, {5, 5}, {at::kCUDA});
92+
93+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
94+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
95+
96+
params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
97+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
98+
5299
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
53100
}

0 commit comments

Comments
 (0)