@@ -146,7 +146,7 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype
146
146
147
147
148
148
def make_segmentation_masks (
149
- image_sizes = ((32 , 32 ), (32 , 42 ), (38 , 24 )),
149
+ image_sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 )),
150
150
dtypes = (torch .long ,),
151
151
extra_dims = ((), (4 ,), (2 , 3 )),
152
152
):
@@ -485,7 +485,7 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
485
485
expected_mask [0 , out_y , out_x ] = mask [0 , in_y , in_x ]
486
486
return expected_mask .to (mask .device )
487
487
488
- for mask in make_segmentation_masks (extra_dims = ((), (4 , ))):
488
+ for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
489
489
output_mask = F .affine_segmentation_mask (
490
490
mask ,
491
491
angle = angle ,
@@ -515,15 +515,13 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
515
515
516
516
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
517
517
def test_correctness_affine_segmentation_mask_on_fixed_input (device ):
518
- # Check transformation against known expected output
518
+ # Check transformation against known expected output and CPU/CUDA devices
519
519
520
- # Create a fixed input segmentation mask with 4 square masks
521
- # in top-left, top-right, bottom-right corners and in the center
520
+ # Create a fixed input segmentation mask with 2 square masks
521
+ # in top-left, bottom-left corners
522
522
mask = torch .zeros (1 , 32 , 32 , dtype = torch .long , device = device )
523
523
mask [0 , 2 :10 , 2 :10 ] = 1
524
524
mask [0 , 32 - 9 : 32 - 3 , 3 :9 ] = 2
525
- mask [0 , 1 :11 , 32 - 11 : 32 - 1 ] = 3
526
- mask [0 , 16 - 4 : 16 + 4 , 16 - 4 : 16 + 4 ] = 4
527
525
528
526
# Rotate 90 degrees and scale
529
527
expected_mask = torch .rot90 (mask , k = - 1 , dims = (- 2 , - 1 ))
0 commit comments