Skip to content

Commit f6f5e3e

Browse files
committed
fix: Update "reduceAxes" variable in GlobalPoolingConverter function and add corresponding uTests
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 7191959 commit f6f5e3e

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

Diff for: core/conversion/converters/impl/pooling.cpp

+2-1
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,9 @@ bool GlobalPoolingConverter(
1616
nvinfer1::PoolingType pool_type) {
1717
auto in = args[0].ITensorOrFreeze(ctx);
1818
nvinfer1::Dims dims = in->getDimensions();
19+
auto out_size = util::toDims(args[1].unwrapToIntList());
1920
// Generate a bitmask of all 1s except the last 2 bits (N and C axes)
20-
uint32_t reduceAxes = ((1 << dims.nbDims) - 1) & ~0b11;
21+
uint32_t reduceAxes = ((1 << dims.nbDims) - 1) ^ ((1 << (dims.nbDims - out_size.nbDims)) - 1);
2122
auto* new_layer = ctx->net->addReduce(
2223
*in,
2324
pool_type == nvinfer1::PoolingType::kMAX ? nvinfer1::ReduceOperation::kMAX : nvinfer1::ReduceOperation::kAVG,

Diff for: tests/core/conversion/converters/test_pooling.cpp

+52
Original file line number | Diff line number | Diff line change
@@ -436,6 +436,32 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
436436
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
437437
}
438438

439+
TEST(Converters, ATenAdaptiveAvgPool2DGlobalPoolingConvertsCorrectly) {
440+
const auto graph = R"IR(
441+
graph(%0 : Tensor):
442+
%2 : int = prim::Constant[value=1]()
443+
%3 : int = prim::Constant[value=1]()
444+
%6 : int[] = prim::ListConstruct(%2, %3)
445+
%10 : Tensor = aten::adaptive_avg_pool2d(%0, %6)
446+
return (%10))IR";
447+
448+
auto g = std::make_shared<torch::jit::Graph>();
449+
torch::jit::parseIR(graph, g.get());
450+
451+
// PyTorch adaptive_avg_pool2d needs a 4D input or a 3D input
452+
auto in = at::randint(-5, 5, {64, 16, 32, 32}, at::kCUDA);
453+
454+
auto jit_in = at::clone(in);
455+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
456+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
457+
458+
auto trt_in = at::clone(in);
459+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
460+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
461+
462+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
463+
}
464+
439465
TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {
440466
const auto graph = R"IR(
441467
graph(%0 : Tensor):
@@ -488,6 +514,32 @@ TEST(Converters, ATenAdaptiveAvgPool1DConvertsCorrectly) {
488514
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1.0));
489515
}
490516

517+
TEST(Converters, ATenAdaptiveAvgPool1DGlobalPoolingConvertsCorrectly) {
518+
const auto graph =
519+
R"IR(
520+
graph(%0 : Tensor):
521+
%2 : int = prim::Constant[value=1]()
522+
%6 : int[] = prim::ListConstruct(%2)
523+
%10 : Tensor = aten::adaptive_avg_pool1d(%0, %6)
524+
return (%10))IR";
525+
526+
auto g = std::make_shared<torch::jit::Graph>();
527+
torch::jit::parseIR(graph, g.get());
528+
529+
// PyTorch adaptive_avg_pool1d needs a 3D input or a 2D input
530+
auto in = at::randint(-5, 5, {3, 16}, at::kCUDA);
531+
532+
auto jit_in = at::clone(in);
533+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
534+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
535+
536+
auto trt_in = at::clone(in);
537+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
538+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
539+
540+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
541+
}
542+
491543
TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectly) {
492544
const auto graph = R"IR(
493545
graph(%0 : Tensor):

0 commit comments

Comments
 (0)