@@ -32,6 +32,97 @@ TEST(Converters, ATenMaxPool2DConvertsCorrectly) {
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

+TEST(Converters, ATenAvgPool2DConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=1]()
+      %3 : int = prim::Constant[value=2]()
+      %4 : bool = prim::Constant[value=0]()
+      %5 : bool = prim::Constant[value=1]()
+      %6 : int[] = prim::ListConstruct(%1, %1)
+      %7 : int[] = prim::ListConstruct(%2, %2)
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : None = prim::Constant()
+      %10 : Tensor = aten::avg_pool2d(%0, %8, %7, %6, %4, %5, %9)
+      return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  // PyTorch AvgPool needs a 3D input
+  auto in = at::randint(-5, 5, {1, 4, 4}, at::kCUDA);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+
+TEST(Converters, ATenAvgPool2DCeilConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=1]()
+      %3 : int = prim::Constant[value=2]()
+      %4 : bool = prim::Constant[value=0]()
+      %5 : bool = prim::Constant[value=1]()
+      %6 : int[] = prim::ListConstruct(%1, %1)
+      %7 : int[] = prim::ListConstruct(%2, %2)
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : None = prim::Constant()
+      %10 : Tensor = aten::avg_pool2d(%0, %8, %7, %6, %5, %5, %9)
+      return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  // PyTorch AvgPool needs a 3D input
+  auto in = at::randint(-5, 5, {1, 4, 4}, at::kCUDA);
86
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
87
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
88
+
89
+ in = at::clone (in);
90
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
91
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
92
+
93
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
94
+ }
95
+
96
+TEST(Converters, ATenAvgPool2DNoCountPadConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int = prim::Constant[value=0]()
+      %2 : int = prim::Constant[value=1]()
+      %3 : int = prim::Constant[value=2]()
+      %4 : bool = prim::Constant[value=0]()
+      %5 : bool = prim::Constant[value=1]()
+      %6 : int[] = prim::ListConstruct(%1, %1)
+      %7 : int[] = prim::ListConstruct(%2, %2)
+      %8 : int[] = prim::ListConstruct(%3, %3)
+      %9 : None = prim::Constant()
+      %10 : Tensor = aten::avg_pool2d(%0, %8, %7, %6, %4, %4, %9)
+      return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  // PyTorch AvgPool needs a 3D input
+  auto in = at::randint(-5, 5, {1, 4, 4}, at::kCUDA);
116
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
117
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
118
+
119
+ in = at::clone (in);
120
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
121
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
122
+
123
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
124
+ }
125
+
35
TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
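Note (not part of the patch): the three new tests differ only in the last two boolean arguments to aten::avg_pool2d, i.e. ceil_mode and count_include_pad. As a quick reference, here is a minimal eager-mode libtorch sketch of those same configurations; the input shape mirrors the tests, but the use of at::randn, CPU tensors, and the variable names are illustrative assumptions, not code from this diff.

```cpp
// Sketch of the three aten::avg_pool2d configurations exercised by the new tests.
// Assumes a libtorch build; shapes and values are illustrative only.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Same shape the tests use: a single-channel 4x4 input (3D: C x H x W).
  auto in = at::randn({1, 4, 4});

  // ATenAvgPool2DConvertsCorrectly:
  //   kernel {2,2}, stride {1,1}, padding {0,0}, ceil_mode=false, count_include_pad=true
  auto base = at::avg_pool2d(in, {2, 2}, {1, 1}, {0, 0},
                             /*ceil_mode=*/false, /*count_include_pad=*/true,
                             /*divisor_override=*/c10::nullopt);

  // ATenAvgPool2DCeilConvertsCorrectly: ceil_mode=true rounds the output size up
  // when the kernel does not tile the input evenly.
  auto ceil = at::avg_pool2d(in, {2, 2}, {1, 1}, {0, 0},
                             /*ceil_mode=*/true, /*count_include_pad=*/true,
                             /*divisor_override=*/c10::nullopt);

  // ATenAvgPool2DNoCountPadConvertsCorrectly: count_include_pad=false excludes
  // zero-padding from the averaging divisor (only observable when padding > 0).
  auto no_pad = at::avg_pool2d(in, {2, 2}, {1, 1}, {0, 0},
                               /*ceil_mode=*/false, /*count_include_pad=*/false,
                               /*divisor_override=*/c10::nullopt);

  std::cout << base.sizes() << " " << ceil.sizes() << " " << no_pad.sizes() << "\n";
  return 0;
}
```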