Commit cdf767c

Merge branch 'dynamic_interpolate' of https://github.com/uni19/TRTorch into uni19-dynamic_interpolate
2 parents: 08b2455 + 0d65220

File tree

2 files changed: +386, -76 lines

Diff for: core/conversion/converters/impl/interpolate.cpp

+152, -64
@@ -46,19 +46,51 @@ void resize_layer_size(
     const torch::jit::Node* n,
     nvinfer1::ITensor* in,
     std::vector<int64_t> out_shape,
+    std::vector<float> scales,
     nvinfer1::ResizeMode mode,
     bool align_corners = false) {
+  TRTORCH_CHECK((out_shape.size() > 0) ^ (scales.size() > 0), "only one of out_shape or scales should be defined");
   auto resize_layer = ctx->net->addResize(*in);
   TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
 
-  resize_layer->setOutputDimensions(util::toDims(out_shape));
+  if (out_shape.size() > 0) {
+    auto th_dynamic_shape_mask = torch::zeros(out_shape.size(), torch::kInt32);
+    auto th_static_shape_mask = torch::zeros(out_shape.size(), torch::kInt32);
+    for (size_t idx = 0; idx < out_shape.size(); ++idx) {
+      if (out_shape[idx] == -1) {
+        th_dynamic_shape_mask[idx] = 1;
+      } else {
+        th_static_shape_mask[idx] = out_shape[idx];
+      }
+    }
+
+    auto dynamic_shape_mask = tensor_to_const(ctx, th_dynamic_shape_mask);
+    auto static_shape_mask = tensor_to_const(ctx, th_static_shape_mask);
+    auto input_shape = ctx->net->addShape(*in)->getOutput(0);
+    auto dynamic_shape = ctx->net
+                             ->addElementWise(*input_shape, *dynamic_shape_mask,
+                                              nvinfer1::ElementWiseOperation::kPROD)
+                             ->getOutput(0);
+    auto target_output_shape = ctx->net
+                                   ->addElementWise(*dynamic_shape, *static_shape_mask,
+                                                    nvinfer1::ElementWiseOperation::kSUM)
+                                   ->getOutput(0);
+    resize_layer->setInput(1, *target_output_shape);
+  } else {
+    resize_layer->setScales(scales.data(), scales.size());
+    if (align_corners) {
+      LOG_WARNING("interpolate with align_corners and scale_factor works differently in TensorRT and PyTorch.");
+    }
+  }
+
   resize_layer->setResizeMode(mode);
   resize_layer->setName(util::node_info(n).c_str());
 
   // if interpolation mode is linear, align corners must have been set to true.
   // else, don't use align corners.
   if (mode == nvinfer1::ResizeMode::kLINEAR) {
 #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
+    TRTORCH_CHECK(align_corners, "resize layer only support align_corner with TensorRT <= 7.0");
     resize_layer->setAlignCorners(true);
 #else
     resize_layer->setAlignCorners(align_corners);
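
The new out_shape path above builds the resize target at network build time: a -1 entry keeps the corresponding (possibly dynamic) input dimension, any other entry is taken literally, and the two masks are combined as input_shape * dynamic_mask + static_mask. Below is a minimal standalone sketch of the same arithmetic, using plain std::vector and made-up shape values in place of the IShapeLayer/IElementWiseLayer network calls:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical values for illustration: a 4-D NCHW input and a requested
  // out_shape where -1 means "keep this dimension from the input".
  std::vector<int64_t> input_shape = {1, 3, 32, 32};
  std::vector<int64_t> out_shape = {-1, -1, 64, 64};

  // Same roles as th_dynamic_shape_mask / th_static_shape_mask above.
  std::vector<int64_t> dynamic_mask(out_shape.size()), static_mask(out_shape.size());
  for (size_t idx = 0; idx < out_shape.size(); ++idx) {
    dynamic_mask[idx] = (out_shape[idx] == -1) ? 1 : 0;
    static_mask[idx] = (out_shape[idx] == -1) ? 0 : out_shape[idx];
  }

  // target = input_shape * dynamic_mask + static_mask
  // (the converter performs this with IElementWiseLayer kPROD/kSUM on the
  //  IShapeLayer output, then feeds the result to IResizeLayer via setInput(1, ...)).
  for (size_t idx = 0; idx < out_shape.size(); ++idx) {
    std::cout << input_shape[idx] * dynamic_mask[idx] + static_mask[idx] << " ";
  }
  std::cout << std::endl;  // prints: 1 3 64 64
  return 0;
}
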
@@ -77,26 +109,29 @@ void resize_layer_size(
 auto interpolate_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
-            {"aten::upsample_nearest1d.vec(Tensor self, int[] output_size, float? scales=None) -> (Tensor)",
+            {"aten::upsample_nearest1d(Tensor self, int[] output_size, float? scales=None) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in = args[0].ITensor();
                auto in_shape = util::toVec(in->getDimensions());
 
-               // Case 1: user uses output size and not scales
-               if (!args[1].IValue()->isNone() && args[2].IValue()->isNone()) {
+               if (args[1].IValue()->isNone() && args[2].IValue()->isNone()) {
+                 TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                << "\nOne of size or scale_factor should be defined");
+               } else if (!args[2].IValue()->isNone()) {
+                 // Case 1: user uses scales
+                 float scale = args[2].IValue()->toDouble();
+                 std::vector<float> padded_scales(in_shape.size(), 1);
+                 padded_scales[padded_scales.size() - 1] = scale;
+                 resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+               } else {
+                 // Case 2: user uses output size
                  auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
-
                  TRTORCH_ASSERT(
                      out_size.size() == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch");
 
                  auto out_shape = in_shape;
                  std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-
-                 resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
-               } else {
-                 TRTORCH_THROW_ERROR(
-                     "Unable to convert node: "
-                     << util::node_info(n) << "\nScale factor parameter for upsample_nearest1d not supported yet.");
+                 resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
                }
 
                return true;
@@ -107,21 +142,26 @@ auto interpolate_registrations TRTORCH_UNUSED =
                auto in = args[0].ITensor();
                auto in_shape = util::toVec(in->getDimensions());
 
-               // Case 1: user uses output_size and not scales_h, scales_w
-               if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone()) {
+               if (args[1].IValue()->isNone() && (args[2].IValue()->isNone() || args[3].IValue()->isNone())) {
+                 TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                << "\nOne of size or scale_factor should be defined");
+               } else if (!args[2].IValue()->isNone() && !args[3].IValue()->isNone()) {
+                 // Case 1: user uses scales
+                 float scale_h = args[2].IValue()->toDouble();
+                 float scale_w = args[3].IValue()->toDouble();
+                 std::vector<float> padded_scales(in_shape.size(), 1);
+                 padded_scales[padded_scales.size() - 2] = scale_h;
+                 padded_scales[padded_scales.size() - 1] = scale_w;
+                 resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+               } else {
+                 // Case 2: user uses output size
                  auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
-
                  TRTORCH_ASSERT(
                      out_size.size() == 2, "aten::upsample_nearest2d input Tensor and output size dimension mismatch");
 
                  auto out_shape = in_shape;
                  std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-
-                 resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
-               } else {
-                 TRTORCH_THROW_ERROR(
-                     "Unable to convert node: "
-                     << util::node_info(n) << "\nScale factor parameter for upsample_nearest2d not supported yet.");
+                 resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
                }
 
                return true;
@@ -132,59 +172,77 @@ auto interpolate_registrations TRTORCH_UNUSED =
                auto in = args[0].ITensor();
                auto in_shape = util::toVec(in->getDimensions());
 
-               // Case 1: user uses output size and not scales_d, scales_h,
-               // scales_w
-               if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone() &&
-                   args[4].IValue()->isNone()) {
+               if (args[1].IValue()->isNone() && (args[2].IValue()->isNone() || args[3].IValue()->isNone() ||
+                                                  args[4].IValue()->isNone())) {
+                 TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                << "\nOne of size or scale_factor should be defined");
+               } else if (!args[2].IValue()->isNone() && !args[3].IValue()->isNone() && !args[4].IValue()->isNone()) {
+                 // Case 1: user uses scales
+                 float scale_d = args[2].IValue()->toDouble();
+                 float scale_h = args[3].IValue()->toDouble();
+                 float scale_w = args[4].IValue()->toDouble();
+                 std::vector<float> padded_scales(in_shape.size(), 1);
+                 padded_scales[padded_scales.size() - 3] = scale_d;
+                 padded_scales[padded_scales.size() - 2] = scale_h;
+                 padded_scales[padded_scales.size() - 1] = scale_w;
+                 resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+               } else {
+                 // Case 2: user uses output size
                  auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
-
                  TRTORCH_ASSERT(
                      out_size.size() == 3, "aten::upsample_nearest3d input Tensor and output size dimension mismatch");
 
                  auto out_shape = in_shape;
                  std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-
-                 resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
-               } else {
-                 TRTORCH_THROW_ERROR(
-                     "Unable to convert node: "
-                     << util::node_info(n) << "\nScale factor parameter for upsample_nearest3d not supported yet.");
-               }
+                 resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+               }
 
                return true;
              }})
         .pattern(
-            {"aten::upsample_linear1d.vec(Tensor self, int[] output_size, bool align_corners, float[]? scales) -> (Tensor)",
+            {"aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners, float? scales) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in = args[0].ITensor();
                auto in_shape = util::toVec(in->getDimensions());
                bool align_corners = args[2].unwrapToBool();
 
-               // Case 1: user uses output size and not scales
-               if (!args[1].IValue()->isNone() && args[3].IValue()->isNone()) {
+               if (args[1].IValue()->isNone() && args[3].IValue()->isNone()) {
+                 TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                << "\nOne of size or scale_factor should be defined");
+               } else if (!args[3].IValue()->isNone()) {
+                 // Case 1: user uses scales
+                 float scale = args[3].IValue()->toDouble();
+                 std::vector<float> padded_scales(in_shape.size(), 1);
+                 padded_scales[padded_scales.size() - 1] = scale;
+#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
+                 if (!align_corners) {
+                   TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                  << "\nupsample_linear1d only supports align_corner with TensorRT <= 7.0.");
+                 } else {
+                   resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true);
+                 }
+#else
+                 resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+#endif
+               } else {
+                 // Case 2: user uses output size
                  auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
-
                  TRTORCH_ASSERT(
                      out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch");
 
                  auto out_shape = in_shape;
                  std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-
 #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
                  if (!align_corners) {
                    // align_corners not supported in TensorRT, create plugin and
                    // run layer through PyTorch
                    create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear"));
                  } else {
-                   resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
+                   resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true);
                  }
 #else
-                 resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                 resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
 #endif
-               } else {
-                 TRTORCH_THROW_ERROR(
-                     "Unable to convert node: " << util::node_info(n)
-                                                << "\nScale factor parameter for upsample_linear1d not supported yet.");
                }
 
                return true;
@@ -196,8 +254,28 @@ auto interpolate_registrations TRTORCH_UNUSED =
                auto in_shape = util::toVec(in->getDimensions());
                bool align_corners = args[2].unwrapToBool();
 
-               // Case 1: user uses output size and not scales_h, scales_w
-               if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
+               if (args[1].IValue()->isNone() && (args[3].IValue()->isNone() || args[4].IValue()->isNone())) {
+                 TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                << "\nOne of size or scale_factor should be defined");
+               } else if (!args[3].IValue()->isNone() && !args[4].IValue()->isNone()) {
+                 // Case 1: user uses scales
+                 float scale_h = args[3].IValue()->toDouble();
+                 float scale_w = args[4].IValue()->toDouble();
+                 std::vector<float> padded_scales(in_shape.size(), 1);
+                 padded_scales[padded_scales.size() - 2] = scale_h;
+                 padded_scales[padded_scales.size() - 1] = scale_w;
+#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
+                 if (!align_corners) {
+                   TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                  << "\nupsample_linear2d only supports align_corner with TensorRT <= 7.0.");
+                 } else {
+                   resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true);
+                 }
+#else
+                 resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+#endif
+               } else {
+                 // Case 2: user uses output size
                  auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
 
                  TRTORCH_ASSERT(
@@ -212,15 +290,11 @@ auto interpolate_registrations TRTORCH_UNUSED =
                    // run layer through PyTorch
                    create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
                  } else {
-                   resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
+                   resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true);
                  }
 #else
-                 resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                 resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
 #endif
-               } else {
-                 TRTORCH_THROW_ERROR(
-                     "Unable to convert node: "
-                     << util::node_info(n) << "\nScale factor parameter for upsample_bilinear2d not supported yet.");
                }
 
                return true;
@@ -232,35 +306,49 @@ auto interpolate_registrations TRTORCH_UNUSED =
                auto in_shape = util::toVec(in->getDimensions());
                bool align_corners = args[2].unwrapToBool();
 
-               // Case 1: user uses output size and not scales_d, scales_h,
-               // scales_w
-               if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone() &&
-                   args[5].IValue()->isNone()) {
+               if (args[1].IValue()->isNone() && (args[3].IValue()->isNone() || args[4].IValue()->isNone() || args[5].IValue()->isNone())) {
+                 TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                << "\nOne of size or scale_factor should be defined");
+               } else if (!args[3].IValue()->isNone() && !args[4].IValue()->isNone() && !args[5].IValue()->isNone()) {
+                 // Case 1: user uses scales
+                 float scale_d = args[3].IValue()->toDouble();
+                 float scale_h = args[4].IValue()->toDouble();
+                 float scale_w = args[5].IValue()->toDouble();
+                 std::vector<float> padded_scales(in_shape.size(), 1);
+                 padded_scales[padded_scales.size() - 3] = scale_d;
+                 padded_scales[padded_scales.size() - 2] = scale_h;
+                 padded_scales[padded_scales.size() - 1] = scale_w;
+#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
+                 if (!align_corners) {
+                   TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n)
+                                                                  << "\nupsample_linear3d only supports align_corner with TensorRT <= 7.0.");
+                 } else {
+                   resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, true);
+                 }
+#else
+                 resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+#endif
+               } else {
+                 // Case 2: user uses output size
                  auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
-
                  TRTORCH_ASSERT(
                      out_size.size() == 3,
                      "aten::upsample_trilinear3d input Tensor and output size dimension mismatch");
 
                  auto out_shape = in_shape;
                  std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-
 #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
                  if (!align_corners) {
                    // align_corners not supported in TensorRT, create plugin and
                    // run layer through PyTorch
                    create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear"));
                  } else {
-                   resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
+                   resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, true);
                  }
 #else
-                 resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
+                 resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
 #endif
-               } else {
-                 TRTORCH_THROW_ERROR(
-                     "Unable to convert node: "
-                     << util::node_info(n) << "\nScale factor parameter for upsample_trilinear3d not supported yet.");
-               }
+               }
 
                return true;
              }});
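
In the scale_factor branches above, each converter pads the per-dimension scales with 1 so that only the trailing spatial dimensions are resized by the TensorRT resize layer. Below is a minimal standalone sketch of that padding for a 4-D NCHW input; the concrete sizes and scale values are made up for illustration and are not part of the commit:

#include <iostream>
#include <vector>

int main() {
  // Hypothetical values: a 4-D NCHW input converted through
  // aten::upsample_nearest2d with scale_factor = (2.0, 2.0).
  size_t input_rank = 4;  // plays the role of in_shape.size() in the converter
  float scale_h = 2.0f;   // args[2] for upsample_nearest2d
  float scale_w = 2.0f;   // args[3] for upsample_nearest2d

  // Leading dims (N, C) keep a scale of 1; only the trailing spatial dims change.
  std::vector<float> padded_scales(input_rank, 1);
  padded_scales[padded_scales.size() - 2] = scale_h;
  padded_scales[padded_scales.size() - 1] = scale_w;

  // These are the values handed to resize_layer->setScales(scales.data(), scales.size()).
  for (float s : padded_scales) {
    std::cout << s << " ";  // prints: 1 1 2 2
  }
  std::cout << std::endl;
  return 0;
}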
