Skip to content

Commit 94d6a0f

Browse files
committed
fix(//core/conversion/converters): Fix plugin implementation for TRT 7
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent cff4211 commit 94d6a0f

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

Diff for: core/conversion/converters/impl/conv_deconv.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
4545

4646
deconv->setStrideNd(stride);
4747
deconv->setPaddingNd(padding);
48+
#if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR == 1)
4849
deconv->setDilationNd(dilation);
4950
deconv->setNbGroups(groups);
50-
51+
#endif
5152
new_layer = deconv;
5253
} else {
5354
nvinfer1::IConvolutionLayer* conv;

Diff for: core/conversion/converters/impl/interpolate.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
153153
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
154154
create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear"));
155155
} else {
156-
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
156+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
157157
}
158158
#else
159159
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
@@ -185,7 +185,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
185185
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
186186
create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
187187
} else {
188-
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
188+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
189189
}
190190
#else
191191
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
@@ -217,7 +217,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
217217
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
218218
create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear"));
219219
} else {
220-
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
220+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
221221
}
222222
#else
223223
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);

Diff for: core/conversion/converters/impl/plugins/interpolate_plugin.cpp

+4-9
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons
178178

179179
cudaStreamWaitEvent(torch_stream.stream(), event, 0);
180180

181-
if (mode == "linear") {
181+
if (mode_ == "linear") {
182182
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
183-
} else if (mode == "bilinear") {
183+
} else if (mode_ == "bilinear") {
184184
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
185-
} else if (mode == "trilinear") {
185+
} else if (mode_ == "trilinear") {
186186
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
187-
} else if (mode == "adaptive_pool2d") {
187+
} else if (mode_ == "adaptive_pool2d") {
188188
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
189189
}
190190

@@ -212,11 +212,6 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons
212212
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
213213
}
214214

215-
output = output.contiguous();
216-
for (int i = 0; i < util::volume(outputDesc->dims); i++) {
217-
std::cout << ((float*)output.data_ptr())[i] << std::endl;
218-
}
219-
220215
cudaMemcpyAsync(outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
221216
cudaStreamSynchronize(stream);
222217

0 commit comments

Comments
 (0)