Skip to content

Commit c2d3a6e

Browse files
committed
feat(aten::cat): Implements aten::cat and completes support for SSD
Signed-off-by: Naren Dasan <[email protected]>
1 parent 619e345 commit c2d3a6e

File tree

5 files changed

+112
-3
lines changed

5 files changed

+112
-3
lines changed

Diff for: core/conversion/converters/BUILD

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ cc_library(
3333
deps = [
3434
"@tensorrt//:nvinfer",
3535
"//core/util:prelude",
36-
"//core/conversion/arg",
36+
"//core/conversion/var",
37+
"//core/conversion/tensorcontainer",
3738
"//core/conversion/conversionctx",
3839
] + select({
3940
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],

Diff for: core/conversion/converters/impl/concat.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
#include "core/util/prelude.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// Converter for aten::cat(Tensor[] tensors, int dim=0) -> Tensor.
//
// The incoming list may mix two element kinds:
//  * TRT ITensors wrapped in a TensorContainer custom class (produced by
//    the prim::ListConstruct evaluator), used directly; and
//  * frozen torch::Tensors (e.g. parameters), which are materialized into
//    the network as constant layers before concatenation.
auto cat_registrations = RegisterNodeConversionPatterns()
    .pattern({
        "aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            auto ts = args[0].IValue()->toListRef();
            auto dim = args[1].unwrapToInt();

            std::vector<nvinfer1::ITensor*> tensors;
            for (auto t : ts) {
                if (t.isTensor()) {
                    // Frozen tensor: lift it into the network as a constant
                    // layer so it can participate in the concatenation.
                    auto torch_tensor = t.toTensor();
                    auto t_weights = Weights(ctx, torch_tensor);
                    auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
                    TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
                    tensors.push_back(const_layer->getOutput(0));
                } else {
                    // Already a TRT ITensor stashed in a TensorContainer.
                    auto cont = t.toCustomClass<TensorContainer>();
                    tensors.push_back(cont->tensor());
                }
            }

            TRTORCH_CHECK(tensors.size() > 0, "aten::cat requires a non-empty tensor list (node: " << *n << ")");

            // aten::cat accepts negative dims (counted from the end);
            // IConcatenationLayer::setAxis requires a non-negative axis.
            if (dim < 0) {
                dim += tensors[0]->getDimensions().nbDims;
            }

            auto cat_layer = ctx->net->addConcatenation(tensors.data(), tensors.size());
            TRTORCH_CHECK(cat_layer, "Unable to create concatenation layer from node: " << *n);
            cat_layer->setAxis(static_cast<int>(dim));
            auto cat_out = ctx->AssociateValueAndTensor(n->outputs()[0], cat_layer->getOutput(0));

            LOG_DEBUG("Output tensor shape: " << cat_out->getDimensions());

            return true;
        }
    });
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch

Diff for: core/conversion/evaluators/prim.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,12 @@ auto prim_registrations = RegisterNodeEvaluators()
7474
auto list = c10::impl::GenericList(elementType);
7575
list.reserve(num_inputs);
7676
for (auto in : n->inputs()) {
77-
auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
78-
list.emplace_back(std::move(x));
77+
if (args.at(in).isITensor()) {
78+
auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
79+
list.emplace_back(std::move(x));
80+
} else {
81+
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
82+
}
7983
}
8084
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
8185
}

Diff for: tests/core/converters/BUILD

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ converter_test(
1515
name = "test_batch_norm"
1616
)
1717

18+
converter_test(
19+
name = "test_concat"
20+
)
21+
1822
converter_test(
1923
name = "test_conv_deconv"
2024
)

Diff for: tests/core/converters/test_concat.cpp

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

// Both inputs are runtime graph inputs; the converter should concatenate
// two live ITensors along dim 0.
TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=0]()
        %4 : Tensor = aten::cat(%2, %3)
        return (%4))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  const auto input_a = at::randint(1, 10, {5}, {at::kCUDA});
  const auto input_b = at::randint(1, 10, {5}, {at::kCUDA});

  auto jit_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {});
  const auto jit_results = trtorch::tests::util::RunGraph(parsed_g, jit_params, {input_a, input_b});

  auto trt_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {});
  const auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_g, trt_params, {input_a, input_b});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

// One input is a frozen parameter (%1 bound via named params); the converter
// must mix a constant layer with a live ITensor in the same concatenation.
TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %1 : Float(5)):
        %2 : Tensor[] = prim::ListConstruct(%0, %1)
        %3 : int = prim::Constant[value=0]()
        %4 : Tensor = aten::cat(%2, %3)
        return (%4))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  const auto runtime_in = at::randint(1, 10, {5}, {at::kCUDA});
  const auto frozen_in = at::randint(1, 10, {5}, {at::kCUDA});

  auto jit_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {frozen_in});
  const auto jit_results = trtorch::tests::util::RunGraph(parsed_g, jit_params, {runtime_in});

  auto trt_params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {frozen_in});
  const auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_g, trt_params, {runtime_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

0 commit comments

Comments
 (0)