@@ -436,6 +436,32 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
436
436
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
437
437
}
438
438
439
+ TEST (Converters, ATenAdaptiveAvgPool2DGlobalPoolingConvertsCorrectly) {
440
+ const auto graph = R"IR(
441
+ graph(%0 : Tensor):
442
+ %2 : int = prim::Constant[value=1]()
443
+ %3 : int = prim::Constant[value=1]()
444
+ %6 : int[] = prim::ListConstruct(%2, %3)
445
+ %10 : Tensor = aten::adaptive_avg_pool2d(%0, %6)
446
+ return (%10))IR" ;
447
+
448
+ auto g = std::make_shared<torch::jit::Graph>();
449
+ torch::jit::parseIR (graph, g.get ());
450
+
451
+ // PyTorch PyTorch adaptive_avg_pool2d needs a 4D input or a 3D input
452
+ auto in = at::randint (-5 , 5 , {64 , 16 , 32 , 32 }, at::kCUDA );
453
+
454
+ auto jit_in = at::clone (in);
455
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
456
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
457
+
458
+ auto trt_in = at::clone (in);
459
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
460
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
461
+
462
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
463
+ }
464
+
439
465
TEST (Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {
440
466
const auto graph = R"IR(
441
467
graph(%0 : Tensor):
@@ -488,6 +514,32 @@ TEST(Converters, ATenAdaptiveAvgPool1DConvertsCorrectly) {
488
514
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 1.0 ));
489
515
}
490
516
517
+ TEST (Converters, ATenAdaptiveAvgPool1DGlobalPoolingConvertsCorrectly) {
518
+ const auto graph =
519
+ R"IR(
520
+ graph(%0 : Tensor):
521
+ %2 : int = prim::Constant[value=1]()
522
+ %6 : int[] = prim::ListConstruct(%2)
523
+ %10 : Tensor = aten::adaptive_avg_pool1d(%0, %6)
524
+ return (%10))IR" ;
525
+
526
+ auto g = std::make_shared<torch::jit::Graph>();
527
+ torch::jit::parseIR (graph, g.get ());
528
+
529
+ // PyTorch adaptive_avg_pool1d needs a 3D input or a 2D input
530
+ auto in = at::randint (-5 , 5 , {3 , 16 }, at::kCUDA );
531
+
532
+ auto jit_in = at::clone (in);
533
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
534
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
535
+
536
+ auto trt_in = at::clone (in);
537
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
538
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
539
+
540
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
541
+ }
542
+
491
543
TEST (Converters, ATenAdaptiveMaxPool2DConvertsCorrectly) {
492
544
const auto graph = R"IR(
493
545
graph(%0 : Tensor):
0 commit comments