Commit d17decb

Fixed formatting and more updates according to the review

1 parent 9d6ac74 commit d17decb

1 file changed (+5, -7 lines)

test/test_prototype_transforms_functional.py
Lines changed: 5 additions & 7 deletions
@@ -146,7 +146,7 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype


 def make_segmentation_masks(
-    image_sizes=((32, 32), (32, 42), (38, 24)),
+    image_sizes=((16, 16), (7, 33), (31, 9)),
     dtypes=(torch.long,),
     extra_dims=((), (4,), (2, 3)),
 ):
@@ -485,7 +485,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
                     expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
         return expected_mask.to(mask.device)

-    for mask in make_segmentation_masks(extra_dims=((), (4, ))):
+    for mask in make_segmentation_masks(extra_dims=((), (4,))):
         output_mask = F.affine_segmentation_mask(
             mask,
             angle=angle,
@@ -515,15 +515,13 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):

 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_affine_segmentation_mask_on_fixed_input(device):
-    # Check transformation against known expected output
+    # Check transformation against known expected output and CPU/CUDA devices

-    # Create a fixed input segmentation mask with 4 square masks
-    # in top-left, top-right, bottom-right corners and in the center
+    # Create a fixed input segmentation mask with 2 square masks
+    # in top-left, bottom-left corners
     mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
     mask[0, 2:10, 2:10] = 1
     mask[0, 32 - 9 : 32 - 3, 3:9] = 2
-    mask[0, 1:11, 32 - 11 : 32 - 1] = 3
-    mask[0, 16 - 4 : 16 + 4, 16 - 4 : 16 + 4] = 4

     # Rotate 90 degrees and scale
     expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1))
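
The sketch below is a minimal, self-contained illustration (using only torch; it deliberately does not call F.affine_segmentation_mask, whose full argument list is not shown in this hunk) of how the fixed-input test builds its input and expected output: two labeled squares in the top-left and bottom-left corners, with the expectation computed as a 90-degree clockwise rotation via torch.rot90.

import torch

# Fixed input: a 1 x 32 x 32 long mask with two labeled squares, matching the
# construction in test_correctness_affine_segmentation_mask_on_fixed_input.
mask = torch.zeros(1, 32, 32, dtype=torch.long)
mask[0, 2:10, 2:10] = 1                # 8 x 8 square in the top-left corner
mask[0, 32 - 9 : 32 - 3, 3:9] = 2      # 6 x 6 square in the bottom-left corner

# Expected output of a 90-degree clockwise rotation, as the test computes it.
expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1))

# Sanity checks: a rotation permutes pixels, so each label keeps its pixel
# count, and the top-left square (label 1) lands in the top-right corner.
assert (expected_mask == 1).sum().item() == (mask == 1).sum().item() == 8 * 8
assert (expected_mask == 2).sum().item() == (mask == 2).sum().item() == 6 * 6
assert bool((expected_mask[0, 2:10, 22:30] == 1).all())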
