1
1
#include " torch/torch.h"
2
2
#include " core/util/prelude.h"
3
3
#include " core/conversion/converters/converters.h"
4
+ #include " NvInfer.h"
5
+ #include " plugins/interpolate_plugin.h"
4
6
5
7
#include < csignal>
6
8
@@ -108,7 +110,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
108
110
auto in = args[0 ].ITensor ();
109
111
auto in_shape = util::toVec (in->getDimensions ());
110
112
111
- bool align_corners = args[2 ].IValue ()-> to < bool > ();
113
+ bool align_corners = args[2 ].unwrapToBool ();
112
114
113
115
// Case 1: user uses output size and not scales
114
116
if (!args[1 ].IValue ()->isNone () && args[3 ].IValue ()->isNone ()) {
@@ -119,16 +121,29 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
119
121
auto out_shape = in_shape;
120
122
std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
121
123
122
- auto resize_layer = ctx->net ->addResize (*in);
123
- TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
124
+ if (!align_corners) {
125
+ // auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
126
+ // auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
127
+ auto creator = new plugins::InterpolatePluginCreator ();
124
128
125
- resize_layer->setOutputDimensions (util::toDims (out_shape));
126
- resize_layer->setResizeMode (nvinfer1::ResizeMode::kLINEAR );
127
- resize_layer->setAlignCorners (align_corners);
128
- resize_layer->setName (util::node_info (n).c_str ());
129
+ auto plugin = creator->createPlugin (util::node_info (n).c_str (), in_shape, out_shape, out_size, std::string (" linear" ), align_corners);
129
130
130
- auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
131
- LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
131
+ auto resize_layer = ctx->net ->addPluginV2 (reinterpret_cast <nvinfer1::ITensor* const *>(in), 1 , *plugin);
132
+
133
+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
134
+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
135
+ } else {
136
+ auto resize_layer = ctx->net ->addResize (*in);
137
+ TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
138
+
139
+ resize_layer->setOutputDimensions (util::toDims (out_shape));
140
+ resize_layer->setResizeMode (nvinfer1::ResizeMode::kLINEAR );
141
+ resize_layer->setAlignCorners (align_corners);
142
+ resize_layer->setName (util::node_info (n).c_str ());
143
+
144
+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
145
+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
146
+ }
132
147
} else {
133
148
TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_linear1d not supported yet." );
134
149
}
0 commit comments