Skip to content

Commit 141763f

Browse files
committed
feat(): addressed some PR comments, refactored code
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 52be580 commit 141763f

File tree

1 file changed

+18
-43
lines changed

1 file changed

+18
-43
lines changed

Diff for: core/conversion/converters/impl/interpolate.cpp

+18-43
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "torch/torch.h"
22
#include "core/util/prelude.h"
33
#include "core/conversion/converters/converters.h"
4-
#include "NvInfer.h"
54
#include "plugins/interpolate_plugin.h"
5+
#include "NvInfer.h"
66
#include "NvInferRuntimeCommon.h"
77

88
#include <tuple>
@@ -18,26 +18,13 @@ namespace {
1818
* Helper functions
1919
*/
2020

21-
auto parse_nearest(args args) {
22-
auto in = args[0].ITensor();
23-
auto in_shape = util::toVec(in->getDimensions());
24-
25-
return std::make_tuple(in, in_shape);
26-
}
27-
28-
auto parse_linear(args args) {
29-
auto in = args[0].ITensor();
30-
auto in_shape = util::toVec(in->getDimensions());
31-
bool align_corners = args[2].unwrapToBool();
32-
33-
return std::make_tuple(in, in_shape, align_corners);
34-
}
35-
3621
void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name,
3722
std::vector<int64_t> in_shape,
3823
std::vector<int64_t> out_shape,
3924
std::vector<int64_t> out_size,
4025
std::string mode) {
26+
LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may differ.");
27+
4128
auto creator = new plugins::InterpolatePluginCreator();
4229
auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, mode, false);
4330

@@ -79,10 +66,8 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
7966
.pattern({
8067
"aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)",
8168
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
82-
auto parsed = parse_nearest(args);
83-
84-
auto in = std::get<0>(parsed);
85-
auto in_shape = std::get<1>(parsed);
69+
auto in = args[0].ITensor();
70+
auto in_shape = util::toVec(in->getDimensions());
8671

8772
// Case 1: user uses output size and not scales
8873
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone()) {
@@ -103,10 +88,8 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
10388
}).pattern({
10489
"aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)",
10590
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
106-
auto parsed = parse_nearest(args);
107-
108-
auto in = std::get<0>(parsed);
109-
auto in_shape = std::get<1>(parsed);
91+
auto in = args[0].ITensor();
92+
auto in_shape = util::toVec(in->getDimensions());
11093

11194
// Case 1: user uses output_size and not scales_h, scales_w
11295
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone()){
@@ -127,10 +110,8 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
127110
}).pattern({
128111
"aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)",
129112
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
130-
auto parsed = parse_nearest(args);
131-
132-
auto in = std::get<0>(parsed);
133-
auto in_shape = std::get<1>(parsed);
113+
auto in = args[0].ITensor();
114+
auto in_shape = util::toVec(in->getDimensions());
134115

135116
// Case 1: user uses output size and not scales_d, scales_h, scales_w
136117
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
@@ -151,11 +132,9 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
151132
}).pattern({
152133
"aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)",
153134
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
154-
auto parsed = parse_linear(args);
155-
156-
auto in = std::get<0>(parsed);
157-
auto in_shape = std::get<1>(parsed);
158-
auto align_corners = std::get<2>(parsed);
135+
auto in = args[0].ITensor();
136+
auto in_shape = util::toVec(in->getDimensions());
137+
bool align_corners = args[2].unwrapToBool();
159138

160139
// Case 1: user uses output size and not scales
161140
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone()) {
@@ -181,11 +160,9 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
181160
}).pattern({
182161
"aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)",
183162
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
184-
auto parsed = parse_linear(args);
185-
186-
auto in = std::get<0>(parsed);
187-
auto in_shape = std::get<1>(parsed);
188-
auto align_corners = std::get<2>(parsed);
163+
auto in = args[0].ITensor();
164+
auto in_shape = util::toVec(in->getDimensions());
165+
bool align_corners = args[2].unwrapToBool();
189166

190167
// Case 1: user uses output size and not scales_h, scales_w
191168
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
@@ -211,11 +188,9 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
211188
}).pattern({
212189
"aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)",
213190
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
214-
auto parsed = parse_linear(args);
215-
216-
auto in = std::get<0>(parsed);
217-
auto in_shape = std::get<1>(parsed);
218-
auto align_corners = std::get<2>(parsed);
191+
auto in = args[0].ITensor();
192+
auto in_shape = util::toVec(in->getDimensions());
193+
bool align_corners = args[2].unwrapToBool();
219194

220195
// Case 1: user uses output size and not scales_d, scales_h, scales_w
221196
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone() && args[5].IValue()->isNone()) {

0 commit comments

Comments
 (0)