@@ -10,11 +10,11 @@ class TestPoolConverter(DispatchTestCase):
10
10
@parameterized .expand (
11
11
[
12
12
(3 , 1 , 0 ),
13
- (3 , 1 , 1 ),
14
- (2 , None , 0 ),
15
- (4 , 1 , 1 ),
16
- (5 , 2 , 0 ),
17
- (7 , 2 , 1 ),
13
+ (( 3 ,), ( 1 ,), ( 1 ,) ),
14
+ (( 2 ,), [], ( 0 ,) ),
15
+ (( 4 ,), ( 1 ,), ( 1 ,) ),
16
+ (( 5 ,), ( 2 ,), ( 0 ,) ),
17
+ (( 7 ,), ( 2 ,), ( 1 ,) ),
18
18
]
19
19
)
20
20
def test_avg_pool1d (
@@ -26,14 +26,10 @@ def test_avg_pool1d(
26
26
count_include_pad = True ,
27
27
):
28
28
class TestModule (torch .nn .Module ):
29
- def __init__ (self ):
30
- super ().__init__ ()
31
- self .pool = torch .nn .AvgPool1d (
32
- kernel_size , stride , padding , ceil_mode , count_include_pad
33
- )
34
-
35
29
def forward (self , x ):
36
- return self .pool (x )
30
+ return torch .ops .aten .avg_pool1d .default (
31
+ x , kernel_size , stride , padding , ceil_mode , count_include_pad
32
+ )
37
33
38
34
inputs = [torch .randn (1 , 3 , 32 )]
39
35
self .run_test (
@@ -46,7 +42,7 @@ def forward(self, x):
46
42
[
47
43
(3 , 1 , 0 ),
48
44
(3 , 1 , 1 ),
49
- ((2 , 2 ), None , (1 , 0 )),
45
+ ((2 , 2 ), [] , (1 , 0 )),
50
46
((4 , 3 ), (1 , 1 ), (1 , 1 )),
51
47
((5 , 4 ), (2 , 1 ), (1 , 0 )),
52
48
((7 , 7 ), (1 , 2 ), (0 , 1 )),
@@ -62,9 +58,9 @@ def test_avg_pool2d(
62
58
divisor_override = None ,
63
59
):
64
60
class TestModule (torch .nn .Module ):
65
- def __init__ (self ):
66
- super (). __init__ ()
67
- self . pool = torch . nn . AvgPool2d (
61
+ def forward (self , x ):
62
+ return torch . ops . aten . avg_pool2d . default (
63
+ x ,
68
64
kernel_size ,
69
65
stride ,
70
66
padding ,
@@ -73,17 +69,14 @@ def __init__(self):
73
69
divisor_override ,
74
70
)
75
71
76
- def forward (self , x ):
77
- return self .pool (x )
78
-
79
72
inputs = [torch .randn (1 , 3 , 32 , 32 )]
80
73
self .run_test (TestModule (), inputs , use_dynamo_tracer = True )
81
74
82
75
@parameterized .expand (
83
76
[
84
77
(3 , 1 , 0 ),
85
78
(3 , 1 , 1 ),
86
- ((2 , 2 , 3 ), None , (1 , 0 , 1 )),
79
+ ((2 , 2 , 3 ), [] , (1 , 0 , 1 )),
87
80
((4 , 3 , 2 ), (1 , 1 , 1 ), (1 , 1 , 0 )),
88
81
((5 , 4 , 3 ), (2 , 1 , 2 ), (1 , 0 , 1 )),
89
82
((7 , 7 , 7 ), (1 , 2 , 1 ), (0 , 1 , 1 )),
@@ -99,9 +92,9 @@ def test_avg_pool3d(
99
92
divisor_override = None ,
100
93
):
101
94
class TestModule (torch .nn .Module ):
102
- def __init__ (self ):
103
- super (). __init__ ()
104
- self . pool = torch . nn . AvgPool3d (
95
+ def forward (self , x ):
96
+ return torch . ops . aten . avg_pool3d . default (
97
+ x ,
105
98
kernel_size ,
106
99
stride ,
107
100
padding ,
@@ -110,20 +103,17 @@ def __init__(self):
110
103
divisor_override ,
111
104
)
112
105
113
- def forward (self , x ):
114
- return self .pool (x )
115
-
116
106
inputs = [torch .randn (1 , 3 , 32 , 32 , 32 )]
117
107
self .run_test (TestModule (), inputs , use_dynamo_tracer = True )
118
108
119
109
@parameterized .expand (
120
110
[
121
111
(3 , 1 , 0 ),
122
- (3 , 1 , 1 ),
123
- (2 , None , 0 ),
124
- (4 , 1 , 1 ),
125
- (5 , 2 , 0 ),
126
- (7 , 2 , 1 ),
112
+ (( 3 ,), ( 1 ,), ( 1 ,) ),
113
+ (( 2 ,), [], ( 0 ,) ),
114
+ (( 4 ,), ( 1 ,), ( 1 ,) ),
115
+ (( 5 ,), ( 2 ,), ( 0 ,) ),
116
+ (( 7 ,), ( 2 ,), ( 1 ,) ),
127
117
]
128
118
)
129
119
def test_max_pool1d (
@@ -132,18 +122,13 @@ def test_max_pool1d(
132
122
stride ,
133
123
padding ,
134
124
dilation = 1 ,
135
- return_indices = False ,
136
125
ceil_mode = False ,
137
126
):
138
127
class TestModule (torch .nn .Module ):
139
- def __init__ (self ):
140
- super ().__init__ ()
141
- self .pool = torch .nn .MaxPool1d (
142
- kernel_size , stride , padding , dilation , return_indices , ceil_mode
143
- )
144
-
145
128
def forward (self , x ):
146
- return self .pool (x )
129
+ return torch .ops .aten .max_pool1d .default (
130
+ x , kernel_size , stride , padding , dilation , ceil_mode
131
+ )
147
132
148
133
inputs = [torch .randn (1 , 3 , 32 )]
149
134
self .run_test (
@@ -157,7 +142,7 @@ def forward(self, x):
157
142
[
158
143
(3 , 1 , 0 ),
159
144
(3 , 1 , 1 ),
160
- ((2 , 2 ), None , (1 , 0 )),
145
+ ((2 , 2 ), [] , (1 , 0 )),
161
146
((4 , 3 ), (1 , 1 ), (1 , 1 )),
162
147
((5 , 4 ), (2 , 1 ), (1 , 0 )),
163
148
((7 , 7 ), (1 , 2 ), (0 , 1 )),
@@ -169,32 +154,27 @@ def test_max_pool2d(
169
154
stride ,
170
155
padding ,
171
156
dilation = 1 ,
172
- return_indices = False ,
173
157
ceil_mode = False ,
174
158
):
175
159
class TestModule (torch .nn .Module ):
176
- def __init__ (self ):
177
- super ().__init__ ()
178
- self .pool = torch .nn .MaxPool2d (
179
- kernel_size ,
180
- stride ,
181
- padding ,
182
- dilation ,
183
- return_indices ,
184
- ceil_mode ,
185
- )
186
-
187
160
def forward (self , x ):
188
- return self .pool (x )
161
+ return torch .ops .aten .max_pool2d .default (
162
+ x , kernel_size , stride , padding , dilation , ceil_mode
163
+ )
189
164
190
165
inputs = [torch .randn (1 , 3 , 32 , 32 )]
191
- self .run_test (TestModule (), inputs , use_dynamo_tracer = True , enable_passes = True )
166
+ self .run_test (
167
+ TestModule (),
168
+ inputs ,
169
+ use_dynamo_tracer = True ,
170
+ enable_passes = True ,
171
+ )
192
172
193
173
@parameterized .expand (
194
174
[
195
175
(3 , 1 , 0 ),
196
176
(3 , 1 , 1 ),
197
- ((2 , 2 , 3 ), None , (1 , 0 , 1 )),
177
+ ((2 , 2 , 3 ), [] , (1 , 0 , 1 )),
198
178
((4 , 3 , 2 ), (1 , 1 , 1 ), (1 , 1 , 0 )),
199
179
((5 , 4 , 3 ), (2 , 1 , 2 ), (1 , 0 , 1 )),
200
180
((7 , 7 , 7 ), (1 , 2 , 1 ), (0 , 1 , 1 )),
@@ -206,26 +186,21 @@ def test_max_pool3d(
206
186
stride ,
207
187
padding ,
208
188
dilation = 1 ,
209
- return_indices = False ,
210
189
ceil_mode = False ,
211
190
):
212
191
class TestModule (torch .nn .Module ):
213
- def __init__ (self ):
214
- super ().__init__ ()
215
- self .pool = torch .nn .MaxPool3d (
216
- kernel_size ,
217
- stride ,
218
- padding ,
219
- dilation ,
220
- return_indices ,
221
- ceil_mode ,
222
- )
223
-
224
192
def forward (self , x ):
225
- return self .pool (x )
193
+ return torch .ops .aten .max_pool3d .default (
194
+ x , kernel_size , stride , padding , dilation , ceil_mode
195
+ )
226
196
227
197
inputs = [torch .randn (1 , 3 , 32 , 32 , 32 )]
228
- self .run_test (TestModule (), inputs , use_dynamo_tracer = True , enable_passes = True )
198
+ self .run_test (
199
+ TestModule (),
200
+ inputs ,
201
+ use_dynamo_tracer = True ,
202
+ enable_passes = True ,
203
+ )
229
204
230
205
231
206
if __name__ == "__main__" :
0 commit comments