Skip to content

Commit c066581

Browse files
committed
feat(aten::prelu): Implement the multi-channel version of prelu and
broadcasting checks Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 8bc4369 commit c066581

File tree

4 files changed

+93
-33
lines changed

4 files changed

+93
-33
lines changed

Diff for: core/conversion/converters/converters.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ struct Weights {
5555

5656
// Freeze an at::Tensor into the TensorRT network as a constant layer and
// return the resulting ITensor.
//
// ctx: active conversion context owning the INetworkDefinition
// t:   tensor whose data is copied into the engine as weights
// Returns the output tensor of the new IConstantLayer.
inline nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t) {
  auto t_weights = Weights(ctx, t);
  auto const_layer = ctx->net->addConstant(t_weights.shape, t_weights.data);
  // addConstant returns nullptr on failure (e.g. invalid dims/weights),
  // so check before dereferencing instead of crashing on getOutput.
  TRTORCH_CHECK(const_layer, "Unable to freeze tensor");
  const_layer->setName("[Freeze Tensor]");
  return const_layer->getOutput(0);
}
6063

6164
} // namespace converters

Diff for: core/conversion/converters/impl/activation.cpp

+23-7
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,34 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
8888
auto in = args[0].ITensor();
8989
auto slopes = args[1].unwrapToTensor();
9090

91-
//if (slopes.numel() != 1) {
92-
// auto in_dims = util::toVec(in.getDimensions());
93-
// auto per_channel_shape = std::vector<int64_t>(in_dims.begin() + 2, in_dims.end());
94-
// for ()
95-
//}
91+
bool to_reshape = false;
92+
auto original_shape = in->getDimensions();
93+
if (slopes.numel() != 1 && !util::broadcastable(in->getDimensions(), util::toDims(slopes.sizes()), /*multidirectional=*/false)) {
94+
if (util::volume(in->getDimensions()) == util::volume(util::toDims(slopes.sizes()))) {
95+
to_reshape = true;
96+
LOG_DEBUG("Input shape is not broadcastable inserting shuffle layers to reshape to " << util::toDims(slopes.sizes()));
97+
auto in_shuffle = ctx->net->addShuffle(*in);
98+
TRTORCH_CHECK(in_shuffle, "Unable to create resize layer for aten::prelu input");
99+
in_shuffle->setReshapeDimensions(util::toDims(slopes.sizes()));
100+
in_shuffle->setName(std::string("[Reshape in to " + util::toStr(util::toDims(slopes.sizes())) + " for broadcasting]").c_str());
101+
in = in_shuffle->getOutput(0);
102+
}
103+
}
96104

97105
auto slope_tensor = tensor_to_const(ctx, slopes);
98-
99106
auto new_layer = ctx->net->addParametricReLU(*in, *slope_tensor);
100107
new_layer->setName(util::node_info(n).c_str());
101-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
108+
auto out_tensor = new_layer->getOutput(0);
109+
110+
if (to_reshape) {
111+
auto out_shuffle = ctx->net->addShuffle(*out_tensor);
112+
TRTORCH_CHECK(out_shuffle, "Unable to create resize layer for aten::prelu output");
113+
out_shuffle->setReshapeDimensions(original_shape);
114+
out_shuffle->setName((std::string("[Reshape back to ") + util::toStr(original_shape) + std::string("]")).c_str());
115+
out_tensor = out_shuffle->getOutput(0);
116+
}
102117

118+
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
103119
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
104120
return true;
105121
}

Diff for: core/util/trt_util.cpp

+65-25
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,59 @@ namespace trtorch {
66
namespace core {
77
namespace util {
88

9+
// Returns true when tensors of shape `a` and `b` can be broadcast together.
//
// multidirectional == true : NumPy/ONNX-style multidirectional broadcasting —
//   shapes are right-aligned (the shorter one is left-padded with 1s) and each
//   pair of dims must be equal or one of them must be 1.
// multidirectional == false: unidirectional broadcasting — only `b` may be
//   padded/stretched to match `a`; `a` itself is never reshaped, so a
//   higher-rank `b` can never broadcast onto `a`.
bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional) {
  // Identical shapes are trivially broadcastable.
  if (a == b) {
    return true;
  }

  if (multidirectional) {
    // Pad the lower-rank shape with leading 1s so both have equal rank.
    nvinfer1::Dims a_dims_eq;
    nvinfer1::Dims b_dims_eq;
    if (a.nbDims > b.nbDims) {
      a_dims_eq = a;
      b_dims_eq = toDimsPad(toVec(b), a.nbDims);
    } else if (a.nbDims < b.nbDims) {
      a_dims_eq = toDimsPad(toVec(a), b.nbDims);
      b_dims_eq = b;
    } else {
      a_dims_eq = a;
      b_dims_eq = b;
    }

    // Each dim pair must match exactly or contain a 1 on either side.
    for (int i = 0; i < a_dims_eq.nbDims; i++) {
      if (!(a_dims_eq.d[i] == b_dims_eq.d[i] || a_dims_eq.d[i] == 1 || b_dims_eq.d[i] == 1)) {
        return false;
      }
    }
    return true;
  } else {
    nvinfer1::Dims b_dims_eq;
    if (a.nbDims > b.nbDims) {
      b_dims_eq = toDimsPad(toVec(b), a.nbDims);
    } else if (a.nbDims < b.nbDims) {
      // `a` cannot be padded in the unidirectional case.
      return false;
    } else {
      b_dims_eq = b;
    }

    // Only `b` may be stretched, so a 1 is accepted only on `b`'s side.
    for (int i = 0; i < a.nbDims; i++) {
      if (!(b_dims_eq.d[i] == a.d[i] || b_dims_eq.d[i] == 1)) {
        return false;
      }
    }
    return true;
  }
}
60+
61+
962
int64_t volume(const nvinfer1::Dims& d) {
1063
return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
1164
}
@@ -16,10 +69,7 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
1669
return toDims(l);
1770
}
1871

19-
if (pad_to > nvinfer1::Dims::MAX_DIMS) {
20-
//TODO: Handle this with exceptions or whatever
21-
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
22-
}
72+
TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
2373

2474
nvinfer1::Dims dims;
2575
dims.nbDims = pad_to;
@@ -34,10 +84,8 @@ nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to) {
3484
}
3585

3686
nvinfer1::Dims toDims(c10::IntArrayRef l) {
37-
if (l.size() > nvinfer1::Dims::MAX_DIMS) {
38-
//TODO: Handle this with exceptions or whatever
39-
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
40-
}
87+
TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
88+
4189
nvinfer1::Dims dims;
4290
dims.nbDims = l.size();
4391
for (size_t i = 0; i < l.size(); i++) {
@@ -47,10 +95,8 @@ nvinfer1::Dims toDims(c10::IntArrayRef l) {
4795
}
4896

4997
nvinfer1::Dims toDims(c10::List<int64_t> l) {
50-
if (l.size() > nvinfer1::Dims::MAX_DIMS) {
51-
//TODO: Handle this with exceptions or whatever
52-
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
53-
}
98+
TRTORCH_CHECK(l.size() <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
99+
54100
nvinfer1::Dims dims;
55101
dims.nbDims = l.size();
56102
for (size_t i = 0; i < l.size(); i++) {
@@ -65,10 +111,8 @@ nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to) {
65111
return toDims(l);
66112
}
67113

68-
if (pad_to > nvinfer1::Dims::MAX_DIMS) {
69-
//TODO: Handle this with exceptions or whatever
70-
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
71-
}
114+
TRTORCH_CHECK(pad_to <= nvinfer1::Dims::MAX_DIMS, "The list requested to be converted to nvinfer1::Dims exceeds the max number of dimensions for TensorRT");
115+
72116

73117
nvinfer1::Dims dims;
74118
dims.nbDims = pad_to;
@@ -109,7 +153,7 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
109153
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos) {
110154
// acceptable range for pos is [0, d.nbDims]
111155
TRTORCH_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds.");
112-
156+
113157
nvinfer1::Dims dims;
114158

115159
int i = 0;
@@ -148,10 +192,8 @@ std::string toStr(nvinfer1::Dims d) {
148192

149193

150194
nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l) {
151-
if (l.size() != 2) {
152-
//TODO: Handle this with exceptions or whatever
153-
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2");
154-
}
195+
TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2");
196+
155197
nvinfer1::DimsHW dims;
156198
dims.nbDims = l.size();
157199
for (size_t i = 0; i < l.size(); i++) {
@@ -161,10 +203,8 @@ nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l) {
161203
}
162204

163205
nvinfer1::DimsHW toDimsHW(c10::IntArrayRef l) {
164-
if (l.size() != 2) {
165-
//TODO: Handle this with exceptions or whatever
166-
LOG_INTERNAL_ERROR("The list requested to be converted to nvinfer1::DimsHW is not 2");
167-
}
206+
TRTORCH_CHECK(l.size() == 2, "The list requested to be converted to nvinfer1::DimsHW is not 2");
207+
168208
nvinfer1::DimsHW dims;
169209
dims.nbDims = l.size();
170210
for (size_t i = 0; i < l.size(); i++) {

Diff for: core/util/trt_util.h

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ namespace util {
7777

7878
int64_t volume(const nvinfer1::Dims& d);
7979

80+
bool broadcastable(nvinfer1::Dims a, nvinfer1::Dims b, bool multidirectional=true);
8081
nvinfer1::Dims toDimsPad(c10::IntArrayRef l, uint64_t pad_to);
8182
nvinfer1::Dims toDimsPad(c10::List<int64_t> l, uint64_t pad_to);
8283
nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);

0 commit comments

Comments
 (0)