diff --git a/tests/py/dynamo/conversion/test_pool_aten.py b/tests/py/dynamo/conversion/test_pool_aten.py index 93f2094184..ab83a304fd 100644 --- a/tests/py/dynamo/conversion/test_pool_aten.py +++ b/tests/py/dynamo/conversion/test_pool_aten.py @@ -10,11 +10,11 @@ class TestPoolConverter(DispatchTestCase): @parameterized.expand( [ (3, 1, 0), - (3, 1, 1), - (2, None, 0), - (4, 1, 1), - (5, 2, 0), - (7, 2, 1), + ((3,), (1,), (1,)), + ((2,), [], (0,)), + ((4,), (1,), (1,)), + ((5,), (2,), (0,)), + ((7,), (2,), (1,)), ] ) def test_avg_pool1d( @@ -26,14 +26,10 @@ def test_avg_pool1d( count_include_pad=True, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool1d( - kernel_size, stride, padding, ceil_mode, count_include_pad - ) - def forward(self, x): - return self.pool(x) + return torch.ops.aten.avg_pool1d.default( + x, kernel_size, stride, padding, ceil_mode, count_include_pad + ) inputs = [torch.randn(1, 3, 32)] self.run_test( @@ -46,7 +42,7 @@ def forward(self, x): [ (3, 1, 0), (3, 1, 1), - ((2, 2), None, (1, 0)), + ((2, 2), [], (1, 0)), ((4, 3), (1, 1), (1, 1)), ((5, 4), (2, 1), (1, 0)), ((7, 7), (1, 2), (0, 1)), @@ -62,9 +58,9 @@ def test_avg_pool2d( divisor_override=None, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool2d( + def forward(self, x): + return torch.ops.aten.avg_pool2d.default( + x, kernel_size, stride, padding, @@ -73,9 +69,6 @@ def __init__(self): divisor_override, ) - def forward(self, x): - return self.pool(x) - inputs = [torch.randn(1, 3, 32, 32)] self.run_test(TestModule(), inputs, use_dynamo_tracer=True) @@ -83,7 +76,7 @@ def forward(self, x): [ (3, 1, 0), (3, 1, 1), - ((2, 2, 3), None, (1, 0, 1)), + ((2, 2, 3), [], (1, 0, 1)), ((4, 3, 2), (1, 1, 1), (1, 1, 0)), ((5, 4, 3), (2, 1, 2), (1, 0, 1)), ((7, 7, 7), (1, 2, 1), (0, 1, 1)), @@ -99,9 +92,9 @@ def test_avg_pool3d( divisor_override=None, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AvgPool3d( + def forward(self, x): + return torch.ops.aten.avg_pool3d.default( + x, kernel_size, stride, padding, @@ -110,20 +103,17 @@ def __init__(self): divisor_override, ) - def forward(self, x): - return self.pool(x) - inputs = [torch.randn(1, 3, 32, 32, 32)] self.run_test(TestModule(), inputs, use_dynamo_tracer=True) @parameterized.expand( [ (3, 1, 0), - (3, 1, 1), - (2, None, 0), - (4, 1, 1), - (5, 2, 0), - (7, 2, 1), + ((3,), (1,), (1,)), + ((2,), [], (0,)), + ((4,), (1,), (1,)), + ((5,), (2,), (0,)), + ((7,), (2,), (1,)), ] ) def test_max_pool1d( @@ -132,18 +122,13 @@ def test_max_pool1d( stride, padding, dilation=1, - return_indices=False, ceil_mode=False, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool1d( - kernel_size, stride, padding, dilation, return_indices, ceil_mode - ) - def forward(self, x): - return self.pool(x) + return torch.ops.aten.max_pool1d.default( + x, kernel_size, stride, padding, dilation, ceil_mode + ) inputs = [torch.randn(1, 3, 32)] self.run_test( @@ -157,7 +142,7 @@ def forward(self, x): [ (3, 1, 0), (3, 1, 1), - ((2, 2), None, (1, 0)), + ((2, 2), [], (1, 0)), ((4, 3), (1, 1), (1, 1)), ((5, 4), (2, 1), (1, 0)), ((7, 7), (1, 2), (0, 1)), @@ -169,32 +154,27 @@ def test_max_pool2d( stride, padding, dilation=1, - return_indices=False, ceil_mode=False, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool2d( - kernel_size, - stride, - padding, - dilation, - return_indices, - ceil_mode, - ) - def forward(self, x): - return self.pool(x) + return torch.ops.aten.max_pool2d.default( + x, kernel_size, stride, padding, dilation, ceil_mode + ) inputs = [torch.randn(1, 3, 32, 32)] - self.run_test(TestModule(), inputs, use_dynamo_tracer=True, enable_passes=True) + self.run_test( + TestModule(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) @parameterized.expand( [ (3, 1, 0), (3, 1, 1), - ((2, 2, 3), None, (1, 0, 1)), + ((2, 2, 3), [], (1, 0, 1)), ((4, 3, 2), (1, 1, 1), (1, 1, 0)), ((5, 4, 3), (2, 1, 2), (1, 0, 1)), ((7, 7, 7), (1, 2, 1), (0, 1, 1)), @@ -206,26 +186,21 @@ def test_max_pool3d( stride, padding, dilation=1, - return_indices=False, ceil_mode=False, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.MaxPool3d( - kernel_size, - stride, - padding, - dilation, - return_indices, - ceil_mode, - ) - def forward(self, x): - return self.pool(x) + return torch.ops.aten.max_pool3d.default( + x, kernel_size, stride, padding, dilation, ceil_mode + ) inputs = [torch.randn(1, 3, 32, 32, 32)] - self.run_test(TestModule(), inputs, use_dynamo_tracer=True, enable_passes=True) + self.run_test( + TestModule(), + inputs, + use_dynamo_tracer=True, + enable_passes=True, + ) if __name__ == "__main__":