
Commit 362c932

Merge pull request #156 from NVIDIA/prelu
Implements prelu and a broadcasting checker
2 parents fe06d09 + c066581 commit 362c932
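
For context: PReLU (aten::prelu) is a leaky-ReLU variant whose negative-side slope is a learned parameter, optionally one slope per channel. A minimal scalar sketch of the behavior the new converter maps onto TensorRT's parametric ReLU layer (illustration only, not code from this commit):

// Scalar PReLU reference: identity for non-negative inputs, scaled negatives otherwise.
inline float prelu_ref(float x, float slope) {
  return x >= 0.0f ? x : slope * x;
}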

5 files changed, +154 -26 lines changed

Diff for: core/conversion/converters/converters.h

+4-1
@@ -55,7 +55,10 @@ struct Weights {
 
 inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
   auto t_weights = Weights(ctx, t);
-  return ctx->net->addConstant(t_weights.shape, t_weights.data)->getOutput(0);
+  auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
+  TRTORCH_CHECK(const_layer, "Unable to freeze tensor");
+  const_layer->setName("[Freeze Tensor]");
+  return const_layer->getOutput(0);
 }
 
 } // namespace converters
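
The tensor_to_const change guards against addConstant returning a null layer and gives the frozen constant a readable name for debugging. A hedged usage sketch, not standalone code: ctx, args, and in are assumed to come from a converter body, as in the PReLU converter below.

// Inside a converter lambda (sketch):
auto slopes = args[1].unwrapToTensor();          // at::Tensor holding the learned slopes
auto slope_trt = tensor_to_const(ctx, slopes);   // now fails loudly via TRTORCH_CHECK instead of dereferencing null
auto prelu_layer = ctx->net->addParametricReLU(*in, *slope_trt);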

Diff for: core/conversion/converters/impl/activation.cpp

+37
@@ -79,6 +79,43 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
     new_layer->setName(util::node_info(n).c_str());
     auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
 
+    LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+    return true;
+  }
+}).pattern({
+  "aten::prelu(Tensor self, Tensor weight) -> (Tensor)",
+  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+    auto in = args[0].ITensor();
+    auto slopes = args[1].unwrapToTensor();
+
+    bool to_reshape = false;
+    auto original_shape = in->getDimensions();
+    if (slopes.numel() != 1 && !util::broadcastable(in->getDimensions(), util::toDims(slopes.sizes()), /*multidirectional=*/false)) {
+      if (util::volume(in->getDimensions()) == util::volume(util::toDims(slopes.sizes()))) {
+        to_reshape = true;
+        LOG_DEBUG("Input shape is not broadcastable inserting shuffle layers to reshape to " << util::toDims(slopes.sizes()));
+        auto in_shuffle = ctx->net->addShuffle(*in);
+        TRTORCH_CHECK(in_shuffle, "Unable to create resize layer for aten::prelu input");
+        in_shuffle->setReshapeDimensions(util::toDims(slopes.sizes()));
+        in_shuffle->setName(std::string("[Reshape in to " + util::toStr(util::toDims(slopes.sizes())) + " for broadcasting]").c_str());
+        in = in_shuffle->getOutput(0);
+      }
+    }
+
+    auto slope_tensor = tensor_to_const(ctx, slopes);
+    auto new_layer = ctx->net->addParametricReLU(*in, *slope_tensor);
+    new_layer->setName(util::node_info(n).c_str());
+    auto out_tensor = new_layer->getOutput(0);
+
+    if (to_reshape) {
+      auto out_shuffle = ctx->net->addShuffle(*out_tensor);
+      TRTORCH_CHECK(out_shuffle, "Unable to create resize layer for aten::prelu output");
+      out_shuffle->setReshapeDimensions(original_shape);
+      out_shuffle->setName((std::string("[Reshape back to ") + util::toStr(original_shape) + std::string("]")).c_str());
+      out_tensor = out_shuffle->getOutput(0);
+    }
+
+    out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
     LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
     return true;
   }
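
The converter wires aten::prelu to TensorRT's addParametricReLU. When the slope tensor has more than one element and is not unidirectionally broadcastable against the input, but both shapes hold the same number of elements (the multi-channel test below hits this with input {1, 10, 1, 1} and slopes {10}), shuffle layers reshape the input to the slope shape and restore the original shape afterwards. A minimal libtorch sketch of why that round trip preserves PyTorch's per-channel semantics in this case (illustration only, CPU, shapes borrowed from the test):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto in = at::randn({1, 10, 1, 1});
  auto slope = at::randn({10});

  // What the inserted shuffle layers emulate: reshape to the slope shape,
  // apply an elementwise PReLU, then reshape back to the original shape.
  auto flat = in.reshape({10});
  auto y = at::where(flat >= 0, flat, flat * slope);
  auto out = y.reshape({1, 10, 1, 1});

  // Matches PyTorch's per-channel prelu for this shape combination.
  std::cout << at::allclose(out, at::prelu(in, slope)) << std::endl;  // expected: 1
}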

Diff for: core/util/trt_util.cpp

+65-25
@@ -6,6 +6,59 @@ namespace trtorch {
 namespace core {
 namespace util {
 
+bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional) {
+  if (a == b) {
+    return true;
+  }
+
+  if (multidirectional) {
+    nvinfer1::Dims a_dims_eq;
+    nvinfer1::Dims b_dims_eq;
+    if (a.nbDims > b.nbDims) {
+      a_dims_eq = a;
+      b_dims_eq = toDimsPad(toVec(b), a.nbDims);
+    } else if (a.nbDims < b.nbDims) {
+      a_dims_eq = toDimsPad(toVec(a), b.nbDims);
+      b_dims_eq = b;
+    } else {
+      a_dims_eq = a;
+      b_dims_eq = b;
+    }
+
+    bool broadcastable = true;
+    for (int i = 0; i < a_dims_eq.nbDims; i++) {
+      if (b_dims_eq.d[i] == a_dims_eq.d[i] || (b_dims_eq.d[i] == 1 || a_dims_eq.d[i] == 1)) {
+        continue;
+      } else {
+        broadcastable = false;
+        break;
+      }
+    }
+    return broadcastable;
+  } else {
+    nvinfer1::Dims b_dims_eq;
+    if (a.nbDims > b.nbDims) {
+      b_dims_eq = toDimsPad(toVec(b), a.nbDims);
+    } else if (a.nbDims < b.nbDims) {
+      return false;
+    } else {
+      b_dims_eq = b;
+    }
+
+    bool broadcastable = true;
+    for (int i = 0; i < a.nbDims; i++) {
+      if (b_dims_eq.d[i] == a.d[i] || b_dims_eq.d[i] == 1) {
+        continue;
+      } else {
+        broadcastable = false;
+        break;
+      }
+    }
+    return broadcastable;
+  }
+}
+
+
 int64_t volume(const nvinfer1::Dims& d) {
   return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
 }
@@ -16,10 +69,7 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
     return toDims(l);
   }
 
-  if (pad_to > nvinfer1::Dims::MAX_DIMS) {
-    //TODO: Handle this with exceptions or whatever
-    LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
-  }
+  TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
 
   nvinfer1::Dims dims;
   dims.nbDims = pad_to;
@@ -34,10 +84,8 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
 }
 
 nvinfer1::Dims toDims(c10::IntArrayRef l) {
-  if (l.size() > nvinfer1::Dims::MAX_DIMS) {
-    //TODO: Handle this with exceptions or whatever
-    LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
-  }
+  TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
+
   nvinfer1::Dims dims;
   dims.nbDims = l.size();
   for (size_t i = 0; i < l.size(); i++) {
@@ -47,10 +95,8 @@ nvinfer1::Dims toDims(c10::IntArrayRef l) {
 }
 
 nvinfer1::Dims toDims(c10::List<int64_t> l) {
-  if (l.size() > nvinfer1::Dims::MAX_DIMS) {
-    //TODO: Handle this with exceptions or whatever
-    LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
-  }
+  TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
+
   nvinfer1::Dims dims;
   dims.nbDims = l.size();
   for (size_t i = 0; i < l.size(); i++) {
@@ -65,10 +111,8 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
     return toDims(l);
   }
 
-  if (pad_to > nvinfer1::Dims::MAX_DIMS) {
-    //TODO: Handle this with exceptions or whatever
-    LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
-  }
+  TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
+
 
   nvinfer1::Dims dims;
   dims.nbDims = pad_to;
@@ -109,7 +153,7 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
 nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos) {
   // acceptable range for pos is [0, d.nbDims]
   TRTORCH_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds.");
-
+
   nvinfer1::Dims dims;
 
   int i = 0;
@@ -148,10 +192,8 @@ std::string toStr(nvinfer1::Dims d) {
 
 
 nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l) {
-  if (l.size() != 2) {
-    //TODO: Handle this with exceptions or whatever
-    LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2");
-  }
+  TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2");
+
   nvinfer1::DimsHW dims;
   dims.nbDims = l.size();
   for (size_t i = 0; i < l.size(); i++) {
@@ -161,10 +203,8 @@ nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l) {
 }
 
 nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l) {
-  if (l.size() != 2) {
-    //TODO: Handle this with exceptions or whatever
-    LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2");
-  }
+  TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2");
+
   nvinfer1::DimsHW dims;
   dims.nbDims = l.size();
   for (size_t i = 0; i < l.size(); i++) {
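
The new util::broadcastable helper implements two flavors of the usual shape-broadcasting rule over nvinfer1::Dims: the default multidirectional check lets either shape contribute a 1 in any dimension, while the stricter unidirectional check (the one the PReLU converter opts into) only lets b be stretched and rejects b outright if it has more dimensions than a. A self-contained sketch of the same rule over std::vector, compilable without TensorRT; the leading-1 padding for shorter shapes is an assumption about toDimsPad, not something shown in this diff:

#include <cstdint>
#include <iostream>
#include <vector>

// Same decision logic as util::broadcastable, restated over plain vectors.
bool broadcastable_sketch(std::vector<int64_t> a, std::vector<int64_t> b, bool multidirectional = true) {
  if (!multidirectional && b.size() > a.size()) {
    return false;  // unidirectional: b may not out-rank a
  }
  // Pad the shorter shape with leading 1s so the ranks match (assumed padding convention).
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);

  for (size_t i = 0; i < a.size(); i++) {
    bool ok = multidirectional ? (a[i] == b[i] || a[i] == 1 || b[i] == 1)
                               : (a[i] == b[i] || b[i] == 1);
    if (!ok) return false;
  }
  return true;
}

int main() {
  std::cout << broadcastable_sketch({1, 10, 1, 1}, {1, 1, 5, 5}) << "\n";         // 1: either side may be 1
  std::cout << broadcastable_sketch({1, 10, 1, 1}, {1, 1, 5, 5}, false) << "\n";  // 0: only b may be stretched
  std::cout << broadcastable_sketch({1, 10, 1, 1}, {10}, false) << "\n";          // 0: the case that triggers the PReLU reshape path
}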

Diff for: core/util/trt_util.h

+1
@@ -77,6 +77,7 @@ namespace util {
 
 int64_t volume(const nvinfer1::Dims& d);
 
+bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional=true);
 nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
 nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
 nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
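
The declaration defaults multidirectional to true, so callers get the permissive check unless they explicitly ask for the strict one, as the PReLU converter does. Calling convention, with hypothetical Dims variables for illustration:

// in_dims and slope_dims are placeholder nvinfer1::Dims values, not names from this commit.
util::broadcastable(in_dims, slope_dims);                              // multidirectional (default)
util::broadcastable(in_dims, slope_dims, /*multidirectional=*/false);  // strict check used for aten::prelu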

Diff for: tests/core/converters/test_activation.cpp

+47
@@ -109,3 +109,50 @@ TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 
+TEST(Converters, ATenPReLUConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(1)):
+      %3 : Tensor = aten::prelu(%0, %1)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {5}, {at::kCUDA});
+  auto slope = at::randint(-5, 5, {1}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenPReLUMultiChannelConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(10)):
+      %3 : Tensor = aten::prelu(%0, %1)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {1, 10, 1, 1}, {at::kCUDA});
+  auto slope = at::randint(-5, 5, {10}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+
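
Both tests compare eager execution of the TorchScript graph (RunGraph) against the compiled TensorRT engine (RunGraphEngine) with a 2e-6 tolerance. The eager baseline boils down to ATen's prelu; a minimal libtorch sketch of that reference computation using the first test's shapes, on CPU for illustration (not how the test harness is implemented):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Shapes from ATenPReLUConvertsCorrectly; values cast to float explicitly.
  auto in = at::randint(-5, 5, {5}).to(at::kFloat);
  auto slope = at::randint(-5, 5, {1}).to(at::kFloat);

  // With a single-element weight the slope is shared across the whole tensor;
  // this is the reference result the TensorRT engine must reproduce.
  auto out = at::prelu(in, slope);
  std::cout << out << std::endl;
}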
