 #include "core/util/prelude.h"
 #include "core/conversion/converters/converters.h"
+#include "plugins/interpolate_plugin.h"
+
 
 namespace trtorch {
 namespace core {
@@ -273,30 +275,51 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
       in_shape = util::toVec(in->getDimensions());
     }
 
-    auto out_shape = args[1].IValue()->toIntList();
+    // auto out_size = args[1].IValue()->toIntList();
+    auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
+
+    if (ctx->input_is_dynamic) {
+      LOG_WARNING("Pooling layer will be run through ATen, not TensorRT. Performance may differ.");
 
-    std::vector<int64_t> stride(out_shape.size());
-    for (size_t i = 0; i < out_shape.size(); i++) {
-      stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_shape[(out_shape.size() - 1) - i];
-    }
-    LOG_DEBUG("Stride: " << util::toDims(stride));
+      auto out_shape = in_shape;
+      std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
 
-    std::vector<int64_t> window(out_shape.size());
-    for (size_t i = 0; i < out_shape.size(); i++) {
-      window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_shape[out_shape.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
-    }
+      auto creator = new plugins::InterpolatePluginCreator();
+      auto plugin = creator->createPlugin("adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);
 
-    LOG_DEBUG("Window: " << util::toDims(window));
+      auto pooling_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
+      TRTORCH_CHECK(pooling_layer, "Unable to create pooling (interpolation) plugin from node: " << *n);
 
-    auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
-    TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);
+      pooling_layer->setName(util::node_info(n).c_str());
 
-    new_layer->setStrideNd(util::toDims(stride));
+      auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], pooling_layer->getOutput(0));
 
-    new_layer->setName(util::node_info(n).c_str());
-    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+      LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
+    } else {
+      std::vector<int64_t> stride(out_size.size());
+      for (size_t i = 0; i < out_size.size(); i++) {
+        stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_size[(out_size.size() - 1) - i];
+      }
+      LOG_DEBUG("Stride: " << util::toDims(stride));
 
-    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+      std::vector<int64_t> window(out_size.size());
+      for (size_t i = 0; i < out_size.size(); i++) {
+        window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_size[out_size.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
+      }
+
+      LOG_DEBUG("Window: " << util::toDims(window));
+
+      auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
+      TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);
+
+      new_layer->setStrideNd(util::toDims(stride));
+
+      new_layer->setName(util::node_info(n).c_str());
+      auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+
+      LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+    }
+
     return true;
   }
 });
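
For reference, the else branch above lowers aten::adaptive_avg_pool2d to a fixed-window average pool: per trailing spatial dim, stride = in / out (floor division) and window = in - (out - 1) * stride, which reproduces the requested output size since (in - window) / stride + 1 == out. Below is a minimal standalone sketch of that arithmetic; the shapes are illustrative assumptions, not values from the commit.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_shape = {1, 64, 7, 7}; // hypothetical NCHW input
  std::vector<int64_t> out_size = {3, 3};        // adaptive_avg_pool2d target

  std::vector<int64_t> stride(out_size.size());
  std::vector<int64_t> window(out_size.size());
  for (size_t i = 0; i < out_size.size(); i++) {
    // Walk dims from the back so out_size aligns with the trailing dims of in_shape
    int64_t in = in_shape[in_shape.size() - 1 - i];
    int64_t out = out_size[out_size.size() - 1 - i];
    stride[stride.size() - 1 - i] = in / out;                    // floor division
    window[window.size() - 1 - i] = in - (out - 1) * (in / out); // remaining extent
  }

  // 7x7 -> 3x3 gives stride {2, 2}, window {3, 3}; (7 - 3) / 2 + 1 == 3 as required.
  std::cout << "stride: " << stride[0] << "x" << stride[1]
            << ", window: " << window[0] << "x" << window[1] << std::endl;
  return 0;
}

Note that a fixed window matches PyTorch's variable-window adaptive pooling exactly when each input dim divides evenly by its output dim; otherwise it is an approximation.
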
@@ -306,3 +329,5 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
 } // namespace conversion
 } // namespace core
 } // trtorch
+
+
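
In the dynamic-shape branch, the converter cannot precompute window and stride because the input dims are unknown at conversion time, so it delegates to the repo's interpolate plugin, which (per the LOG_WARNING) runs the op through ATen at runtime. The only shape work done at conversion time is overlaying out_size onto the trailing dims of the input shape via std::copy. A small sketch of that overlay follows; the example values are assumptions.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_shape = {1, 64, 7, 7}; // hypothetical NCHW input
  std::vector<int64_t> out_size = {3, 3};        // target spatial dims only

  // Same pattern as the converter: copy out_size over the last two dims
  auto out_shape = in_shape;
  std::copy(out_size.begin(), out_size.end(),
            out_shape.begin() + (in_shape.size() - out_size.size()));

  // out_shape is now {1, 64, 3, 3}; batch and channel dims are untouched.
  for (auto d : out_shape) {
    std::cout << d << " ";
  }
  std::cout << std::endl;
  return 0;
}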