Skip to content

Commit fb07513

Browse files
authored
fix: update tests of pooling converters (#2613)
1 parent ffbcc7a commit fb07513

File tree

1 file changed

+44
-69
lines changed

1 file changed

+44
-69
lines changed

tests/py/dynamo/conversion/test_pool_aten.py

+44-69
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ class TestPoolConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
1212
(3, 1, 0),
13-
(3, 1, 1),
14-
(2, None, 0),
15-
(4, 1, 1),
16-
(5, 2, 0),
17-
(7, 2, 1),
13+
((3,), (1,), (1,)),
14+
((2,), [], (0,)),
15+
((4,), (1,), (1,)),
16+
((5,), (2,), (0,)),
17+
((7,), (2,), (1,)),
1818
]
1919
)
2020
def test_avg_pool1d(
@@ -26,14 +26,10 @@ def test_avg_pool1d(
2626
count_include_pad=True,
2727
):
2828
class TestModule(torch.nn.Module):
29-
def __init__(self):
30-
super().__init__()
31-
self.pool = torch.nn.AvgPool1d(
32-
kernel_size, stride, padding, ceil_mode, count_include_pad
33-
)
34-
3529
def forward(self, x):
36-
return self.pool(x)
30+
return torch.ops.aten.avg_pool1d.default(
31+
x, kernel_size, stride, padding, ceil_mode, count_include_pad
32+
)
3733

3834
inputs = [torch.randn(1, 3, 32)]
3935
self.run_test(
@@ -46,7 +42,7 @@ def forward(self, x):
4642
[
4743
(3, 1, 0),
4844
(3, 1, 1),
49-
((2, 2), None, (1, 0)),
45+
((2, 2), [], (1, 0)),
5046
((4, 3), (1, 1), (1, 1)),
5147
((5, 4), (2, 1), (1, 0)),
5248
((7, 7), (1, 2), (0, 1)),
@@ -62,9 +58,9 @@ def test_avg_pool2d(
6258
divisor_override=None,
6359
):
6460
class TestModule(torch.nn.Module):
65-
def __init__(self):
66-
super().__init__()
67-
self.pool = torch.nn.AvgPool2d(
61+
def forward(self, x):
62+
return torch.ops.aten.avg_pool2d.default(
63+
x,
6864
kernel_size,
6965
stride,
7066
padding,
@@ -73,17 +69,14 @@ def __init__(self):
7369
divisor_override,
7470
)
7571

76-
def forward(self, x):
77-
return self.pool(x)
78-
7972
inputs = [torch.randn(1, 3, 32, 32)]
8073
self.run_test(TestModule(), inputs, use_dynamo_tracer=True)
8174

8275
@parameterized.expand(
8376
[
8477
(3, 1, 0),
8578
(3, 1, 1),
86-
((2, 2, 3), None, (1, 0, 1)),
79+
((2, 2, 3), [], (1, 0, 1)),
8780
((4, 3, 2), (1, 1, 1), (1, 1, 0)),
8881
((5, 4, 3), (2, 1, 2), (1, 0, 1)),
8982
((7, 7, 7), (1, 2, 1), (0, 1, 1)),
@@ -99,9 +92,9 @@ def test_avg_pool3d(
9992
divisor_override=None,
10093
):
10194
class TestModule(torch.nn.Module):
102-
def __init__(self):
103-
super().__init__()
104-
self.pool = torch.nn.AvgPool3d(
95+
def forward(self, x):
96+
return torch.ops.aten.avg_pool3d.default(
97+
x,
10598
kernel_size,
10699
stride,
107100
padding,
@@ -110,20 +103,17 @@ def __init__(self):
110103
divisor_override,
111104
)
112105

113-
def forward(self, x):
114-
return self.pool(x)
115-
116106
inputs = [torch.randn(1, 3, 32, 32, 32)]
117107
self.run_test(TestModule(), inputs, use_dynamo_tracer=True)
118108

119109
@parameterized.expand(
120110
[
121111
(3, 1, 0),
122-
(3, 1, 1),
123-
(2, None, 0),
124-
(4, 1, 1),
125-
(5, 2, 0),
126-
(7, 2, 1),
112+
((3,), (1,), (1,)),
113+
((2,), [], (0,)),
114+
((4,), (1,), (1,)),
115+
((5,), (2,), (0,)),
116+
((7,), (2,), (1,)),
127117
]
128118
)
129119
def test_max_pool1d(
@@ -132,18 +122,13 @@ def test_max_pool1d(
132122
stride,
133123
padding,
134124
dilation=1,
135-
return_indices=False,
136125
ceil_mode=False,
137126
):
138127
class TestModule(torch.nn.Module):
139-
def __init__(self):
140-
super().__init__()
141-
self.pool = torch.nn.MaxPool1d(
142-
kernel_size, stride, padding, dilation, return_indices, ceil_mode
143-
)
144-
145128
def forward(self, x):
146-
return self.pool(x)
129+
return torch.ops.aten.max_pool1d.default(
130+
x, kernel_size, stride, padding, dilation, ceil_mode
131+
)
147132

148133
inputs = [torch.randn(1, 3, 32)]
149134
self.run_test(
@@ -157,7 +142,7 @@ def forward(self, x):
157142
[
158143
(3, 1, 0),
159144
(3, 1, 1),
160-
((2, 2), None, (1, 0)),
145+
((2, 2), [], (1, 0)),
161146
((4, 3), (1, 1), (1, 1)),
162147
((5, 4), (2, 1), (1, 0)),
163148
((7, 7), (1, 2), (0, 1)),
@@ -169,32 +154,27 @@ def test_max_pool2d(
169154
stride,
170155
padding,
171156
dilation=1,
172-
return_indices=False,
173157
ceil_mode=False,
174158
):
175159
class TestModule(torch.nn.Module):
176-
def __init__(self):
177-
super().__init__()
178-
self.pool = torch.nn.MaxPool2d(
179-
kernel_size,
180-
stride,
181-
padding,
182-
dilation,
183-
return_indices,
184-
ceil_mode,
185-
)
186-
187160
def forward(self, x):
188-
return self.pool(x)
161+
return torch.ops.aten.max_pool2d.default(
162+
x, kernel_size, stride, padding, dilation, ceil_mode
163+
)
189164

190165
inputs = [torch.randn(1, 3, 32, 32)]
191-
self.run_test(TestModule(), inputs, use_dynamo_tracer=True, enable_passes=True)
166+
self.run_test(
167+
TestModule(),
168+
inputs,
169+
use_dynamo_tracer=True,
170+
enable_passes=True,
171+
)
192172

193173
@parameterized.expand(
194174
[
195175
(3, 1, 0),
196176
(3, 1, 1),
197-
((2, 2, 3), None, (1, 0, 1)),
177+
((2, 2, 3), [], (1, 0, 1)),
198178
((4, 3, 2), (1, 1, 1), (1, 1, 0)),
199179
((5, 4, 3), (2, 1, 2), (1, 0, 1)),
200180
((7, 7, 7), (1, 2, 1), (0, 1, 1)),
@@ -206,26 +186,21 @@ def test_max_pool3d(
206186
stride,
207187
padding,
208188
dilation=1,
209-
return_indices=False,
210189
ceil_mode=False,
211190
):
212191
class TestModule(torch.nn.Module):
213-
def __init__(self):
214-
super().__init__()
215-
self.pool = torch.nn.MaxPool3d(
216-
kernel_size,
217-
stride,
218-
padding,
219-
dilation,
220-
return_indices,
221-
ceil_mode,
222-
)
223-
224192
def forward(self, x):
225-
return self.pool(x)
193+
return torch.ops.aten.max_pool3d.default(
194+
x, kernel_size, stride, padding, dilation, ceil_mode
195+
)
226196

227197
inputs = [torch.randn(1, 3, 32, 32, 32)]
228-
self.run_test(TestModule(), inputs, use_dynamo_tracer=True, enable_passes=True)
198+
self.run_test(
199+
TestModule(),
200+
inputs,
201+
use_dynamo_tracer=True,
202+
enable_passes=True,
203+
)
229204

230205

231206
if __name__ == "__main__":

0 commit comments

Comments
 (0)