Skip to content

Commit fa227b0

Browse files
committed
feat(): added adaptive_avg_pool2d plugin, and added test for it
Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 9458f21 commit fa227b0

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

Diff for: core/conversion/converters/impl/pooling.cpp

+42-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "core/util/prelude.h"
22
#include "core/conversion/converters/converters.h"
3+
#include "plugins/interpolate_plugin.h"
4+
35

46
namespace trtorch {
57
namespace core {
@@ -273,30 +275,51 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
273275
in_shape = util::toVec(in->getDimensions());
274276
}
275277

276-
auto out_shape = args[1].IValue()->toIntList();
278+
//auto out_size = args[1].IValue()->toIntList();
279+
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
280+
281+
if (ctx->input_is_dynamic) {
282+
LOG_WARNING("Pooling layer will be run through ATen, not TensorRT. Performance may differ.");
277283

278-
std::vector<int64_t> stride(out_shape.size());
279-
for (size_t i = 0; i < out_shape.size(); i++) {
280-
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_shape[(out_shape.size() - 1) - i];
281-
}
282-
LOG_DEBUG("Stride: " << util::toDims(stride));
284+
auto out_shape = in_shape;
285+
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
283286

284-
std::vector<int64_t> window(out_shape.size());
285-
for (size_t i = 0; i < out_shape.size(); i++) {
286-
window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_shape[out_shape.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
287-
}
287+
auto creator = new plugins::InterpolatePluginCreator();
288+
auto plugin = creator->createPlugin("adaptive_pool2d", in_shape, out_shape, out_size, std::string("adaptive_pool2d"), false);
288289

289-
LOG_DEBUG("Window: " << util::toDims(window));
290+
auto pooling_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
291+
TRTORCH_CHECK(pooling_layer, "Unable to create pooling (interpolation) plugin from node" << *n);
290292

291-
auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
292-
TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);
293+
pooling_layer->setName(util::node_info(n).c_str());
293294

294-
new_layer->setStrideNd(util::toDims(stride));
295+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], pooling_layer->getOutput(0));
295296

296-
new_layer->setName(util::node_info(n).c_str());
297-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
297+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
298+
} else {
299+
std::vector<int64_t> stride(out_size.size());
300+
for (size_t i = 0; i < out_size.size(); i++) {
301+
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_size[(out_size.size() - 1) - i];
302+
}
303+
LOG_DEBUG("Stride: " << util::toDims(stride));
298304

299-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
305+
std::vector<int64_t> window(out_size.size());
306+
for (size_t i = 0; i < out_size.size(); i++) {
307+
window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_size[out_size.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
308+
}
309+
310+
LOG_DEBUG("Window: " << util::toDims(window));
311+
312+
auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
313+
TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);
314+
315+
new_layer->setStrideNd(util::toDims(stride));
316+
317+
new_layer->setName(util::node_info(n).c_str());
318+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
319+
320+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
321+
}
322+
300323
return true;
301324
}
302325
});
@@ -306,3 +329,5 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
306329
} // namespace conversion
307330
} // namespace core
308331
} // trtorch
332+
333+

Diff for: tests/core/converters/test_pooling.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -383,3 +383,29 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
383383

384384
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
385385
}
386+
387+
TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {
  // Runs aten::adaptive_avg_pool2d with a fixed output size of (3, 4) through
  // both the TorchScript interpreter and a TensorRT engine built with dynamic
  // input shapes (RunGraphEngineDynamic below), then checks the results match.
  // The dynamic-shape path exercises the converter's plugin-based fallback.
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %2 : int = prim::Constant[value=3]()
      %3 : int = prim::Constant[value=4]()
      %6 : int[] = prim::ListConstruct(%2, %3)
      %10 : Tensor = aten::adaptive_avg_pool2d(%0, %6)
      return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  // NOTE(review): the original comment said "MaxPool" — copy-paste error; this
  // is adaptive average pooling. A 3D (CHW) input tensor is used here.
  auto in = at::randint(-5, 5, {10, 18, 36}, at::kCUDA);

  // Reference run: clone the input and evaluate the graph with TorchScript.
  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  // Test run: same input through a TensorRT engine built with dynamic shapes.
  auto trt_in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  // Elementwise comparison of the two outputs within an absolute tolerance of 2e-6.
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 commit comments

Comments
 (0)