@@ -17,9 +17,31 @@ InterpolatePlugin::InterpolatePlugin(
17
17
std::vector<int64_t > in_shape,
18
18
std::vector<int64_t > out_shape,
19
19
std::vector<int64_t > size,
20
+ std::vector<double > scales,
20
21
std::string mode,
21
- bool align_corners)
22
- : in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners) {}
22
+ bool align_corners,
23
+ bool use_scales)
24
+ : in_shape_(in_shape), out_shape_(out_shape), size_(size), scales_(scales), mode_(mode), align_corners_(align_corners), use_scales_(use_scales) {
25
+ if (use_scales) {
26
+ TRTORCH_ASSERT (mode_ != " adaptive_pool2d" , " use_scales is not valid for adaptive_pool2d" );
27
+ TRTORCH_ASSERT (scales_.size () != 0 , " Attempted to use interpolate plugin without providing scales while use_scales=true" );
28
+ at::Tensor input = at::randint (1 , 10 , in_shape, {at::kCUDA });
29
+ at::Tensor output;
30
+
31
+ if (mode_ == " linear" ) {
32
+ output = at::upsample_linear1d (input, c10::nullopt, align_corners_, scales_[0 ]);
33
+ } else if (mode_ == " bilinear" ) {
34
+ output = at::upsample_bilinear2d (input, c10::nullopt, align_corners_, scales_);
35
+ std::cout << output.sizes () << std::endl;
36
+ } else if (mode_ == " trilinear" ) {
37
+ output = at::upsample_trilinear3d (input, c10::nullopt, align_corners_, scales_);
38
+ }
39
+
40
+ out_shape_ = output.sizes ().vec ();
41
+ } else {
42
+ TRTORCH_ASSERT ((size_.size () != 0 && out_shape_.size () != 0 ), " Attempted to use interpolate plugin without providing output size while use_scales=false" );
43
+ }
44
+ }
23
45
24
46
InterpolatePlugin::InterpolatePlugin (const char * data, size_t length) {
25
47
std::istringstream data_stream (std::string (data, length));
@@ -42,6 +64,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
42
64
input_archive.read (" size" , value);
43
65
size_ = value.toIntVector ();
44
66
}
67
+ {
68
+ torch::IValue value;
69
+ input_archive.read (" scales" , value);
70
+ scales_ = value.toDoubleVector ();
71
+ }
45
72
{
46
73
torch::IValue value;
47
74
input_archive.read (" mode" , value);
@@ -52,6 +79,11 @@ InterpolatePlugin::InterpolatePlugin(const char* data, size_t length) {
52
79
input_archive.read (" align_corners" , value);
53
80
align_corners_ = value.toBool ();
54
81
}
82
+ {
83
+ torch::IValue value;
84
+ input_archive.read (" use_scales" , value);
85
+ use_scales_ = value.toBool ();
86
+ }
55
87
}
56
88
57
89
std::vector<int64_t > InterpolatePlugin::getInputShape () {
@@ -83,7 +115,7 @@ const char* InterpolatePlugin::getPluginNamespace() const {
83
115
}
84
116
85
117
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone () const {
86
- return new InterpolatePlugin (in_shape_, out_shape_, size_, mode_, align_corners_);
118
+ return new InterpolatePlugin (in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_ );
87
119
}
88
120
89
121
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions (
@@ -93,9 +125,27 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
93
125
nvinfer1::IExprBuilder& exprBuilder) {
94
126
nvinfer1::DimsExprs output (inputs[0 ]);
95
127
96
- for (unsigned int i = 0 ; i < out_shape_.size (); i++) {
97
- output.d [i] = exprBuilder.constant (out_shape_[i]);
98
- }
128
+ // TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true to cover
129
+ // the different implementations between PyTorch and TRT. However TRT currently does not support doubles
130
+ // for ExprBuilder constants. Once that is possible enable this code and remove the code in the constructor
131
+ // if (use_scales_) {
132
+ // auto input_dimsexprs = inputs[0];
133
+ // output.d[0] = exprBuilder.operation(DimensionOperation::kMAX, *input_dimsexprs.d[0], *exprBuilder.constant(0));
134
+ // if (mode_ == "linear") {
135
+ // output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
136
+ // } else if (mode_ == "bilinear") {
137
+ // output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
138
+ // output.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
139
+ // } else if (mode_ == "trilinear") {
140
+ // output.d[1] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[1], *exprBuilder.constant(scales_[1]));
141
+ // output.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[2], *exprBuilder.constant(scales_[2]));
142
+ // output.d[3] = exprBuilder.operation(DimensionOperation::kPROD, *input_dimsexprs.d[3], *exprBuilder.constant(scales_[3]));
143
+ // }
144
+ // } else {
145
+ for (unsigned int i = 0 ; i < out_shape_.size (); i++) {
146
+ output.d [i] = exprBuilder.constant (out_shape_[i]);
147
+ }
148
+ // }
99
149
100
150
return output;
101
151
}
@@ -131,8 +181,10 @@ std::string InterpolatePlugin::serializeToString() const {
131
181
output_archive.write (" in_shape" , torch::IValue (in_shape_));
132
182
output_archive.write (" out_shape" , torch::IValue (out_shape_));
133
183
output_archive.write (" size" , torch::IValue (size_));
184
+ output_archive.write (" scales" , torch::IValue (scales_));
134
185
output_archive.write (" mode" , torch::IValue (mode_));
135
186
output_archive.write (" align_corners" , torch::IValue (align_corners_));
187
+ output_archive.write (" use_scales" , torch::IValue (use_scales_));
136
188
137
189
std::ostringstream data_str;
138
190
output_archive.save_to (data_str);
@@ -201,14 +253,24 @@ int InterpolatePlugin::enqueue(
201
253
202
254
cudaStreamWaitEvent (torch_stream.stream (), event, 0 );
203
255
204
- if (mode_ == " linear" ) {
205
- at::upsample_linear1d_out (output, input, {size_[0 ]}, align_corners_);
206
- } else if (mode_ == " bilinear" ) {
207
- at::upsample_bilinear2d_out (output, input, {size_[0 ], size_[1 ]}, align_corners_);
208
- } else if (mode_ == " trilinear" ) {
209
- at::upsample_trilinear3d_out (output, input, {size_[0 ], size_[1 ], size_[2 ]}, align_corners_);
210
- } else if (mode_ == " adaptive_pool2d" ) {
211
- at::adaptive_avg_pool2d_out (output, input, {size_[0 ], size_[1 ]});
256
+ if (use_scales_) {
257
+ if (mode_ == " linear" ) {
258
+ at::upsample_linear1d_out (output, input, {}, align_corners_, scales_[0 ]);
259
+ } else if (mode_ == " bilinear" ) {
260
+ at::upsample_bilinear2d_out (output, input, {}, align_corners_, scales_[0 ], scales_[1 ]);
261
+ } else if (mode_ == " trilinear" ) {
262
+ at::upsample_trilinear3d_out (output, input, {}, align_corners_, scales_[0 ], scales_[1 ], scales_[2 ]);
263
+ }
264
+ } else {
265
+ if (mode_ == " linear" ) {
266
+ at::upsample_linear1d_out (output, input, {size_[0 ]}, align_corners_);
267
+ } else if (mode_ == " bilinear" ) {
268
+ at::upsample_bilinear2d_out (output, input, {size_[0 ], size_[1 ]}, align_corners_);
269
+ } else if (mode_ == " trilinear" ) {
270
+ at::upsample_trilinear3d_out (output, input, {size_[0 ], size_[1 ], size_[2 ]}, align_corners_);
271
+ } else if (mode_ == " adaptive_pool2d" ) {
272
+ at::adaptive_avg_pool2d_out (output, input, {size_[0 ], size_[1 ]});
273
+ }
212
274
}
213
275
214
276
cudaEvent_t torch_event;
@@ -234,11 +296,27 @@ int InterpolatePlugin::enqueue(
234
296
stream);
235
297
cudaStreamSynchronize (stream);
236
298
237
- at::Tensor input = at::from_blob ((void *)input_blob, util::toVec (inputDesc->dims ), tensor_options_);
238
299
300
+ at::Tensor input = at::from_blob ((void *)input_blob, util::toVec (inputDesc->dims ), tensor_options_);
239
301
at::Tensor output;
240
- if (mode_ == " adaptive_pool2d" ) {
241
- output = at::adaptive_avg_pool2d (input, {size_[0 ], size_[1 ]});
302
+ if (use_scales_) {
303
+ if (mode_ == " linear" ) {
304
+ output = at::upsample_linear1d (input, c10::nullopt, align_corners_, {scales_[0 ]});
305
+ } else if (mode_ == " bilinear" ) {
306
+ output = at::upsample_bilinear2d (input, c10::nullopt, align_corners_, scales_);
307
+ } else if (mode_ == " trilinear" ) {
308
+ output = at::upsample_trilinear3d (input, c10::nullopt, align_corners_, scales_);
309
+ }
310
+ } else {
311
+ if (mode_ == " linear" ) {
312
+ output = at::upsample_linear1d (input, {size_[0 ]}, align_corners_);
313
+ } else if (mode_ == " bilinear" ) {
314
+ output = at::upsample_bilinear2d (input, {size_[0 ], size_[1 ]}, align_corners_);
315
+ } else if (mode_ == " trilinear" ) {
316
+ output = at::upsample_trilinear3d (input, {size_[0 ], size_[1 ], size_[2 ]}, align_corners_);
317
+ } else if (mode_ == " adaptive_pool2d" ) {
318
+ output = at::adaptive_avg_pool2d (input, {size_[0 ], size_[1 ]});
319
+ }
242
320
}
243
321
244
322
cudaMemcpyAsync (
@@ -277,10 +355,12 @@ InterpolatePlugin* InterpolatePluginCreator::createPlugin(
277
355
std::vector<int64_t > in_shape,
278
356
std::vector<int64_t > out_shape,
279
357
std::vector<int64_t > size,
358
+ std::vector<double > scales,
280
359
std::string mode,
281
- bool align_corners) {
360
+ bool align_corners,
361
+ bool use_scales) {
282
362
name_ = name;
283
- return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
363
+ return new InterpolatePlugin (in_shape, out_shape, size, scales, mode, align_corners, use_scales );
284
364
}
285
365
286
366
nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin (
0 commit comments