Skip to content

Commit 353f2d2

Browse files
committed
feat(//core/conversion/converters/impl/shuffle): Implement aten::reshape
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent a51c7b6 commit 353f2d2

File tree

8 files changed

+95
-8
lines changed

8 files changed

+95
-8
lines changed

Diff for: core/conversion/converters/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_library(
1717
"impl/linear.cpp",
1818
"impl/pooling.cpp",
1919
"impl/reduce.cpp",
20+
"impl/shuffle.cpp",
2021
"impl/softmax.cpp",
2122
"impl/unary.cpp",
2223
],

Diff for: core/conversion/converters/converters.h

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "torch/csrc/jit/runtime/custom_operator.h"
77
#include "ATen/core/function_schema.h"
88

9+
#include "core/util/prelude.h"
910
#include "core/conversion/conversionctx/ConversionCtx.h"
1011

1112
namespace trtorch {

Diff for: core/conversion/converters/impl/shuffle.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
#include "core/conversion/converters/converters.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// Registers TensorRT converters for tensor shuffle/reshape operations.
// aten::reshape is lowered to an IShuffleLayer whose reshape dimensions are
// the requested shape, padded out to at least 2 dims for TensorRT.
static auto shuffle_registrations = RegisterNodeConversionPatterns()
    .pattern({
        "aten::reshape(Tensor self, int[] shape) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& arguments) -> bool {
            auto self = arguments[0].ITensor();
            // Pad the requested shape to a minimum of 2 dimensions before
            // handing it to TensorRT.
            auto reshape_dims = util::toDimsPad(arguments[1].unwrapToIntList(), 2);

            auto shuffle_layer = ctx->net->addShuffle(*self);
            TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
            shuffle_layer->setReshapeDimensions(reshape_dims);
            shuffle_layer->setName(util::node_info(n).c_str());

            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
            LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

            return true;
        }
    });

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch

Diff for: core/conversion/converters/impl/softmax.cpp

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include "core/util/prelude.h"
21
#include "core/conversion/converters/converters.h"
32

43
namespace trtorch {
@@ -29,12 +28,7 @@ static auto softmax_registrations = RegisterNodeConversionPatterns()
2928
auto softmax = ctx->net->addSoftMax(*in);
3029

3130
TRTORCH_CHECK(softmax, "Unable to create softmax layer from node: " << *n);
32-
33-
if (!softmax) {
34-
LOG_ERROR("Unable to create softmax layer from node: " << *n);
35-
return false;
36-
}
37-
LOG_WARNING("Disregarding dtype argument, please verify");
31+
LOG_DEBUG("Disregarding dtype argument");
3832

3933
if (shape.size() > 3) {
4034
softmax->setAxes(1 << (dim));
@@ -69,4 +63,4 @@ static auto softmax_registrations = RegisterNodeConversionPatterns()
6963
} // namespace converters
7064
} // namespace conversion
7165
} // namespace core
72-
} // trtorch
66+
} // namespace trtorch

Diff for: core/util/trt_util.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,29 @@ nvinfer1::Dims toDims(c10::List<int64_t> l) {
5959
return dims;
6060
}
6161

62+
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
63+
if (l.size() > pad_to) {
64+
LOG_DEBUG("Requested padding of dimensions to " << pad_to << " but found " << l.size() << " dimensions, not going to pad");
65+
return toDims(l);
66+
}
67+
68+
if (pad_to > nvinfer1::Dims::MAX_DIMS) {
69+
//TODO: Handle this with exceptions or whatever
70+
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
71+
}
72+
73+
nvinfer1::Dims dims;
74+
dims.nbDims = pad_to;
75+
for (size_t i = 0; i < pad_to - l.size(); i++) {
76+
dims.d[i] = 1;
77+
}
78+
79+
for (size_t i = pad_to - l.size(); i < pad_to; i++) {
80+
dims.d[i] = l[i - (pad_to - l.size())];
81+
}
82+
return dims;
83+
}
84+
6285
std::vector<int64_t> toVec(nvinfer1::Dims d) {
6386
std::vector<int64_t> dims;
6487
for (int i = 0; i < d.nbDims; i++) {

Diff for: core/util/trt_util.h

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ namespace util {
7878
int64_t volume(const nvinfer1::Dims& d);
7979

8080
nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
81+
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
8182
nvinfer1::Dims toDims(c10::IntArrayRef l);
8283
nvinfer1::Dims toDims(c10::List<int64_t> l);
8384
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);

Diff for: tests/core/converters/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ converter_test(
44
name = "test_softmax"
55
)
66

7+
converter_test(
8+
name = "test_shuffle"
9+
)
10+
711
converter_test(
812
name = "test_activation"
913
)
@@ -36,6 +40,7 @@ test_suite(
3640
name = "test_converters",
3741
tests = [
3842
":test_softmax",
43+
":test_shuffle",
3944
":test_activation",
4045
":test_pooling",
4146
":test_unary",

Diff for: tests/core/converters/test_shuffle.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
#include <string>
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "tests/util/util.h"
#include "core/compiler.h"

// Checks that aten::reshape produces the same result through the TensorRT
// engine as through the JIT interpreter.
TEST(Converters, ATenReshapeConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %1 : int = prim::Constant[value=3]()
      %2 : int = prim::Constant[value=2]()
      %3 : int[] = prim::ListConstruct(%1, %2)
      %4 : Tensor = aten::reshape(%0, %3)
      return (%4))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed.get());

  // Reference result from the JIT interpreter.
  auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(parsed, params, {in});

  // Same graph and an identical input through the TensorRT engine.
  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed, params, {in});
  auto trt = trt_results[0].reshape_as(jit_results[0]);

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

0 commit comments

Comments
 (0)