Skip to content

fix: update tests of pooling converters #2613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 44 additions & 69 deletions tests/py/dynamo/conversion/test_pool_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class TestPoolConverter(DispatchTestCase):
@parameterized.expand(
[
(3, 1, 0),
(3, 1, 1),
(2, None, 0),
(4, 1, 1),
(5, 2, 0),
(7, 2, 1),
((3,), (1,), (1,)),
((2,), [], (0,)),
((4,), (1,), (1,)),
((5,), (2,), (0,)),
((7,), (2,), (1,)),
]
)
def test_avg_pool1d(
Expand All @@ -26,14 +26,10 @@ def test_avg_pool1d(
count_include_pad=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AvgPool1d(
kernel_size, stride, padding, ceil_mode, count_include_pad
)

def forward(self, x):
return self.pool(x)
return torch.ops.aten.avg_pool1d.default(
x, kernel_size, stride, padding, ceil_mode, count_include_pad
)

inputs = [torch.randn(1, 3, 32)]
self.run_test(
Expand All @@ -46,7 +42,7 @@ def forward(self, x):
[
(3, 1, 0),
(3, 1, 1),
((2, 2), None, (1, 0)),
((2, 2), [], (1, 0)),
((4, 3), (1, 1), (1, 1)),
((5, 4), (2, 1), (1, 0)),
((7, 7), (1, 2), (0, 1)),
Expand All @@ -62,9 +58,9 @@ def test_avg_pool2d(
divisor_override=None,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AvgPool2d(
def forward(self, x):
return torch.ops.aten.avg_pool2d.default(
x,
kernel_size,
stride,
padding,
Expand All @@ -73,17 +69,14 @@ def __init__(self):
divisor_override,
)

def forward(self, x):
return self.pool(x)

inputs = [torch.randn(1, 3, 32, 32)]
self.run_test(TestModule(), inputs, use_dynamo_tracer=True)

@parameterized.expand(
[
(3, 1, 0),
(3, 1, 1),
((2, 2, 3), None, (1, 0, 1)),
((2, 2, 3), [], (1, 0, 1)),
((4, 3, 2), (1, 1, 1), (1, 1, 0)),
((5, 4, 3), (2, 1, 2), (1, 0, 1)),
((7, 7, 7), (1, 2, 1), (0, 1, 1)),
Expand All @@ -99,9 +92,9 @@ def test_avg_pool3d(
divisor_override=None,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.AvgPool3d(
def forward(self, x):
return torch.ops.aten.avg_pool3d.default(
x,
kernel_size,
stride,
padding,
Expand All @@ -110,20 +103,17 @@ def __init__(self):
divisor_override,
)

def forward(self, x):
return self.pool(x)

inputs = [torch.randn(1, 3, 32, 32, 32)]
self.run_test(TestModule(), inputs, use_dynamo_tracer=True)

@parameterized.expand(
[
(3, 1, 0),
(3, 1, 1),
(2, None, 0),
(4, 1, 1),
(5, 2, 0),
(7, 2, 1),
((3,), (1,), (1,)),
((2,), [], (0,)),
((4,), (1,), (1,)),
((5,), (2,), (0,)),
((7,), (2,), (1,)),
]
)
def test_max_pool1d(
Expand All @@ -132,18 +122,13 @@ def test_max_pool1d(
stride,
padding,
dilation=1,
return_indices=False,
ceil_mode=False,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.MaxPool1d(
kernel_size, stride, padding, dilation, return_indices, ceil_mode
)

def forward(self, x):
return self.pool(x)
return torch.ops.aten.max_pool1d.default(
x, kernel_size, stride, padding, dilation, ceil_mode
)

inputs = [torch.randn(1, 3, 32)]
self.run_test(
Expand All @@ -157,7 +142,7 @@ def forward(self, x):
[
(3, 1, 0),
(3, 1, 1),
((2, 2), None, (1, 0)),
((2, 2), [], (1, 0)),
((4, 3), (1, 1), (1, 1)),
((5, 4), (2, 1), (1, 0)),
((7, 7), (1, 2), (0, 1)),
Expand All @@ -169,32 +154,27 @@ def test_max_pool2d(
stride,
padding,
dilation=1,
return_indices=False,
ceil_mode=False,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.MaxPool2d(
kernel_size,
stride,
padding,
dilation,
return_indices,
ceil_mode,
)

def forward(self, x):
return self.pool(x)
return torch.ops.aten.max_pool2d.default(
x, kernel_size, stride, padding, dilation, ceil_mode
)

inputs = [torch.randn(1, 3, 32, 32)]
self.run_test(TestModule(), inputs, use_dynamo_tracer=True, enable_passes=True)
self.run_test(
TestModule(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)

@parameterized.expand(
[
(3, 1, 0),
(3, 1, 1),
((2, 2, 3), None, (1, 0, 1)),
((2, 2, 3), [], (1, 0, 1)),
((4, 3, 2), (1, 1, 1), (1, 1, 0)),
((5, 4, 3), (2, 1, 2), (1, 0, 1)),
((7, 7, 7), (1, 2, 1), (0, 1, 1)),
Expand All @@ -206,26 +186,21 @@ def test_max_pool3d(
stride,
padding,
dilation=1,
return_indices=False,
ceil_mode=False,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.MaxPool3d(
kernel_size,
stride,
padding,
dilation,
return_indices,
ceil_mode,
)

def forward(self, x):
return self.pool(x)
return torch.ops.aten.max_pool3d.default(
x, kernel_size, stride, padding, dilation, ceil_mode
)

inputs = [torch.randn(1, 3, 32, 32, 32)]
self.run_test(TestModule(), inputs, use_dynamo_tracer=True, enable_passes=True)
self.run_test(
TestModule(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)


if __name__ == "__main__":
Expand Down