Skip to content

Commit 92e3818

Browse files
committed
feat(interpolate): Addressing the linear, scale factor, align corners edge case
This commit adds support in some cases for the edge case when handling torch.nn.functional.interpolate where the user is doing some form of linear upsampling and uses scale factor to calculate the new tensor size at runtime and they set align corners to true (as of PyTorch 1.5 this is no longer the default behavior). This commit adds support for this case when users chose to construct static input size engines via the interpolate plugin which will run the function from ATen on CPU. In the case of dynamic input shapes with these 3 conditions the compilation will terminate with an error. The ultimate solution will be to find the root cause of the descripancy between PyTorch and TensorRT. Barring that we will need to use the dimension calculation primatives for TensorRT plugins. However, there is a limitation where static values in the computation cannot be floats which PyTorch scale factors are. Therefore it doesn't seem possible currently to support this usecase. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 0cda1cc commit 92e3818

File tree

6 files changed

+307
-198
lines changed

6 files changed

+307
-198
lines changed

Diff for: core/conversion/converters/impl/interpolate.cpp

+85-82
Large diffs are not rendered by default.

Diff for: core/conversion/converters/impl/plugins/interpolate_plugin.cpp

+99-19
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,31 @@ InterpolatePlugin::InterpolatePlugin(
1717
std::vector<int64_t> in_shape,
1818
std::vector<int64_t> out_shape,
1919
std::vector<int64_t> size,
20+
std::vector<double> scales,
2021
std::string mode,
21-
bool align_corners)
22-
: in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners) {}
22+
bool align_corners,
23+
bool use_scales)
24+
: in_shape_(in_shape), out_shape_(out_shape), size_(size), scales_(scales), mode_(mode), align_corners_(align_corners), use_scales_(use_scales) {
25+
if (use_scales) {
26+
TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
27+
TRTORCH_ASSERT(scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
28+
at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});
29+
at::Tensor output;
30+
31+
if (mode_ == "linear") {
32+
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, scales_[0]);
33+
} else if (mode_ == "bilinear") {
34+
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
35+
std::cout << output.sizes() << std::endl;
36+
} else if (mode_ == "trilinear") {
37+
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
38+
}
39+
40+
out_shape_ = output.sizes().vec();
41+
} else {
42+
TRTORCH_ASSERT((size_.size() != 0 && out_shape_.size() != 0), "Attempted to use interpolate plugin without providing output size while use_scales=false");
43+
}
44+
}
2345

2446
InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
2547
std::istringstream data_stream(std::string(data, length));
@@ -42,6 +64,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
4264
input_archive.read("size", value);
4365
size_ = value.toIntVector();
4466
}
67+
{
68+
torch::IValue value;
69+
input_archive.read("scales", value);
70+
scales_ = value.toDoubleVector();
71+
}
4572
{
4673
torch::IValue value;
4774
input_archive.read("mode", value);
@@ -52,6 +79,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
5279
input_archive.read("align_corners", value);
5380
align_corners_ = value.toBool();
5481
}
82+
{
83+
torch::IValue value;
84+
input_archive.read("use_scales", value);
85+
use_scales_ = value.toBool();
86+
}
5587
}
5688

5789
std::vector<int64_t> InterpolatePlugin::getInputShape() {
@@ -83,7 +115,7 @@ const char* InterpolatePlugin::getPluginNamespace() const {
83115
}
84116

85117
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
86-
return new InterpolatePlugin(in_shape_, out_shape_, size_, mode_, align_corners_);
118+
return new InterpolatePlugin(in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_);
87119
}
88120

89121
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
@@ -93,9 +125,27 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
93125
nvinfer1::IExprBuilder& exprBuilder) {
94126
nvinfer1::DimsExprs output(inputs[0]);
95127

96-
for (unsigned int i = 0; i < out_shape_.size(); i++) {
97-
output.d[i] = exprBuilder.constant(out_shape_[i]);
98-
}
128+
// TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true to cover
129+
// the different implementations between PyTorch and TRT. However TRT currently does not support doubles
130+
// for ExprBuilder constants. Once that is possible enable this code and remove the code in the constructor
131+
// if (use_scales_) {
132+
// auto input_dimsexprs = inputs[0];
133+
// output.d[0] = exprBuilder.operation(DimensionOperation::kMAX, *input_dimsexprs.d[0], *exprBuilder.constant(0));
134+
// if (mode_ == "linear") {
135+
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
136+
// } else if (mode_ == "bilinear") {
137+
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
138+
// output.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
139+
// } else if (mode_ == "trilinear") {
140+
// output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
141+
// output.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
142+
// output.d[3] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[3], *exprBuilder.constant(scales_[3]));
143+
// }
144+
// } else {
145+
for (unsigned int i = 0; i < out_shape_.size(); i++) {
146+
output.d[i] = exprBuilder.constant(out_shape_[i]);
147+
}
148+
//}
99149

100150
return output;
101151
}
@@ -131,8 +181,10 @@ std::string InterpolatePlugin::serializeToString() const {
131181
output_archive.write("in_shape", torch::IValue(in_shape_));
132182
output_archive.write("out_shape", torch::IValue(out_shape_));
133183
output_archive.write("size", torch::IValue(size_));
184+
output_archive.write("scales", torch::IValue(scales_));
134185
output_archive.write("mode", torch::IValue(mode_));
135186
output_archive.write("align_corners", torch::IValue(align_corners_));
187+
output_archive.write("use_scales", torch::IValue(use_scales_));
136188

137189
std::ostringstream data_str;
138190
output_archive.save_to(data_str);
@@ -201,14 +253,24 @@ int InterpolatePlugin::enqueue(
201253

202254
cudaStreamWaitEvent(torch_stream.stream(), event, 0);
203255

204-
if (mode_ == "linear") {
205-
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
206-
} else if (mode_ == "bilinear") {
207-
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
208-
} else if (mode_ == "trilinear") {
209-
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
210-
} else if (mode_ == "adaptive_pool2d") {
211-
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
256+
if (use_scales_) {
257+
if (mode_ == "linear") {
258+
at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
259+
} else if (mode_ == "bilinear") {
260+
at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
261+
} else if (mode_ == "trilinear") {
262+
at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
263+
}
264+
} else {
265+
if (mode_ == "linear") {
266+
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
267+
} else if (mode_ == "bilinear") {
268+
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
269+
} else if (mode_ == "trilinear") {
270+
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
271+
} else if (mode_ == "adaptive_pool2d") {
272+
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
273+
}
212274
}
213275

214276
cudaEvent_t torch_event;
@@ -234,11 +296,27 @@ int InterpolatePlugin::enqueue(
234296
stream);
235297
cudaStreamSynchronize(stream);
236298

237-
at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
238299

300+
at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
239301
at::Tensor output;
240-
if (mode_ == "adaptive_pool2d") {
241-
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
302+
if (use_scales_) {
303+
if (mode_ == "linear") {
304+
output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
305+
} else if (mode_ == "bilinear") {
306+
output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
307+
} else if (mode_ == "trilinear") {
308+
output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
309+
}
310+
} else {
311+
if (mode_ == "linear") {
312+
output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
313+
} else if (mode_ == "bilinear") {
314+
output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
315+
} else if (mode_ == "trilinear") {
316+
output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
317+
} else if (mode_ == "adaptive_pool2d") {
318+
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
319+
}
242320
}
243321

244322
cudaMemcpyAsync(
@@ -277,10 +355,12 @@ InterpolatePlugin* InterpolatePluginCreator::createPlugin(
277355
std::vector<int64_t> in_shape,
278356
std::vector<int64_t> out_shape,
279357
std::vector<int64_t> size,
358+
std::vector<double> scales,
280359
std::string mode,
281-
bool align_corners) {
360+
bool align_corners,
361+
bool use_scales) {
282362
name_ = name;
283-
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
363+
return new InterpolatePlugin(in_shape, out_shape, size, scales, mode, align_corners, use_scales);
284364
}
285365

286366
nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(

Diff for: core/conversion/converters/impl/plugins/interpolate_plugin.h

+13-7
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
3131
std::vector<int64_t> in_shape_;
3232
std::vector<int64_t> out_shape_;
3333
std::vector<int64_t> size_;
34+
std::vector<double> scales_;
3435
std::string mode_;
3536
bool align_corners_;
37+
bool use_scales_;
3638

3739
protected:
3840
// To prevent compiler warnings
@@ -49,8 +51,10 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
4951
std::vector<int64_t> in_shape,
5052
std::vector<int64_t> out_shape,
5153
std::vector<int64_t> size,
54+
std::vector<double> scales,
5255
std::string mode,
53-
bool align_corners);
56+
bool align_corners,
57+
bool use_scales);
5458

5559
InterpolatePlugin(const char* data, size_t length);
5660

@@ -136,12 +140,14 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
136140
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override;
137141

138142
InterpolatePlugin* createPlugin(
139-
const char* name,
140-
std::vector<int64_t> in_shape,
141-
std::vector<int64_t> out_shape,
142-
std::vector<int64_t> size,
143-
std::string mode,
144-
bool align_corners);
143+
const char* name,
144+
std::vector<int64_t> in_shape,
145+
std::vector<int64_t> out_shape,
146+
std::vector<int64_t> size,
147+
std::vector<double> scales,
148+
std::string mode,
149+
bool align_corners,
150+
bool use_scales);
145151

146152
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
147153

Diff for: core/conversion/converters/impl/pooling.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ auto pooling_registrations TRTORCH_UNUSED =
317317

318318
auto creator = new plugins::InterpolatePluginCreator();
319319
auto plugin = creator->createPlugin(
320-
"adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);
320+
"adaptive_pool2d", in_shape, out_shape, out_size, {}, std::string("adaptive_pool2d"), false, false);
321321

322322
auto pooling_layer =
323323
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);

0 commit comments

Comments
 (0)