Skip to content

Commit d5e2c36

Browse files
YosuaMichael authored and facebook-github-bot committed
[fbsync] Fix prototype transforms for (*, H, W) segmentation masks (#6574)
Summary:
* add generator functions for segmentation masks
* update functional tests and fix kernels
* fix transforms tests

Reviewed By: jdsgomes

Differential Revision: D39543270

fbshipit-source-id: 6f56ba288277ce0dd98115a794b2002d5341e83f
1 parent f3f64a8 commit d5e2c36

File tree

5 files changed: +166 -42 lines changed

test/prototype_common_utils.py

+40 -4
Original file line number | Diff line number | Diff line change
@@ -178,22 +178,58 @@ def make_one_hot_labels(
178178
yield make_one_hot_label(extra_dims_)
179179

180180

181-
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
181+
def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
182+
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
182183
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
183184
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
184185
shape = (*extra_dims, num_objects, *size)
185186
data = make_tensor(shape, low=0, high=2, dtype=dtype)
186187
return features.SegmentationMask(data)
187188

188189

190+
def make_detection_masks(
191+
*,
192+
sizes=((16, 16), (7, 33), (31, 9)),
193+
dtypes=(torch.uint8,),
194+
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
195+
num_objects=(1, 0, None),
196+
):
197+
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
198+
yield make_detection_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
199+
200+
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
201+
yield make_detection_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
202+
203+
204+
def make_segmentation_mask(size=None, *, num_categories=None, extra_dims=(), dtype=torch.uint8):
205+
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
206+
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
207+
num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
208+
shape = (*extra_dims, *size)
209+
data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
210+
return features.SegmentationMask(data)
211+
212+
189213
def make_segmentation_masks(
214+
*,
190215
sizes=((16, 16), (7, 33), (31, 9)),
191216
dtypes=(torch.uint8,),
192217
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
193-
num_objects=(1, 0, 10),
218+
num_categories=(1, 2, None),
194219
):
195220
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
196221
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
197222

198-
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
199-
yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
223+
for dtype, extra_dims_, num_categories_ in itertools.product(dtypes, extra_dims, num_categories):
224+
yield make_segmentation_mask(size=sizes[0], num_categories=num_categories_, dtype=dtype, extra_dims=extra_dims_)
225+
226+
227+
def make_detection_and_segmentation_masks(
228+
sizes=((16, 16), (7, 33), (31, 9)),
229+
dtypes=(torch.uint8,),
230+
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
231+
num_objects=(1, 0, None),
232+
num_categories=(1, 2, None),
233+
):
234+
yield from make_detection_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_objects=num_objects)
235+
yield from make_segmentation_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_categories=num_categories)

test/test_prototype_transforms.py

+12 -4
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
from prototype_common_utils import (
1111
make_bounding_box,
1212
make_bounding_boxes,
13+
make_detection_and_segmentation_masks,
14+
make_detection_mask,
1315
make_image,
1416
make_images,
1517
make_label,
@@ -62,6 +64,7 @@ def parametrize_from_transforms(*transforms):
6264
make_one_hot_labels,
6365
make_vanilla_tensor_images,
6466
make_pil_images,
67+
make_detection_and_segmentation_masks,
6568
]:
6669
inputs = list(creation_fn())
6770
try:
@@ -131,7 +134,12 @@ def test_mixup_cutmix(self, transform, input):
131134
# Check if we raise an error if sample contains bbox or mask or label
132135
err_msg = "does not support bounding boxes, segmentation masks and plain labels"
133136
input_copy = dict(input)
134-
for unsup_data in [make_label(), make_bounding_box(format="XYXY"), make_segmentation_mask()]:
137+
for unsup_data in [
138+
make_label(),
139+
make_bounding_box(format="XYXY"),
140+
make_detection_mask(),
141+
make_segmentation_mask(),
142+
]:
135143
input_copy["unsupported"] = unsup_data
136144
with pytest.raises(TypeError, match=err_msg):
137145
transform(input_copy)
@@ -233,7 +241,7 @@ def test_convert_color_space_unsupported_types(self):
233241
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
234242
)
235243

236-
for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
244+
for inpt in [make_bounding_box(format="XYXY"), make_detection_and_segmentation_masks()]:
237245
output = transform(inpt)
238246
assert output is inpt
239247

@@ -1206,7 +1214,7 @@ def test__transform(self, mocker):
12061214
bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
12071215
label = features.Label(torch.randint(0, 10, size=(6,)))
12081216
ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
1209-
masks = make_segmentation_mask((32, 24), num_objects=6)
1217+
masks = make_detection_mask((32, 24), num_objects=6)
12101218

12111219
sample = [image, bboxes, label, ohe_label, masks]
12121220

@@ -1578,7 +1586,7 @@ def test__transform_culling(self, mocker):
15781586
bounding_boxes = make_bounding_box(
15791587
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
15801588
)
1581-
segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
1589+
segmentation_masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
15821590
labels = make_label(size=(batch_size,))
15831591

15841592
transform = transforms.FixedSizeCrop((-1, -1))

test/test_prototype_transforms_functional.py

+33 -19
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,14 @@
88
import torch.testing
99
import torchvision.prototype.transforms.functional as F
1010
from common_utils import cpu_and_gpu
11-
from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks
11+
from prototype_common_utils import (
12+
ArgsKwargs,
13+
make_bounding_boxes,
14+
make_detection_and_segmentation_masks,
15+
make_detection_masks,
16+
make_image,
17+
make_images,
18+
)
1219
from torch import jit
1320
from torchvision.prototype import features
1421
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
@@ -55,7 +62,7 @@ def horizontal_flip_bounding_box():
5562

5663
@register_kernel_info_from_sample_inputs_fn
5764
def horizontal_flip_segmentation_mask():
58-
for mask in make_segmentation_masks():
65+
for mask in make_detection_and_segmentation_masks():
5966
yield ArgsKwargs(mask)
6067

6168

@@ -73,7 +80,7 @@ def vertical_flip_bounding_box():
7380

7481
@register_kernel_info_from_sample_inputs_fn
7582
def vertical_flip_segmentation_mask():
76-
for mask in make_segmentation_masks():
83+
for mask in make_detection_and_segmentation_masks():
7784
yield ArgsKwargs(mask)
7885

7986

@@ -118,7 +125,7 @@ def resize_bounding_box():
118125
@register_kernel_info_from_sample_inputs_fn
119126
def resize_segmentation_mask():
120127
for mask, max_size in itertools.product(
121-
make_segmentation_masks(),
128+
make_detection_and_segmentation_masks(),
122129
[None, 34], # max_size
123130
):
124131
height, width = mask.shape[-2:]
@@ -173,7 +180,7 @@ def affine_bounding_box():
173180
@register_kernel_info_from_sample_inputs_fn
174181
def affine_segmentation_mask():
175182
for mask, angle, translate, scale, shear in itertools.product(
176-
make_segmentation_masks(),
183+
make_detection_and_segmentation_masks(),
177184
[-87, 15, 90], # angle
178185
[5, -5], # translate
179186
[0.77, 1.27], # scale
@@ -226,7 +233,7 @@ def rotate_bounding_box():
226233
@register_kernel_info_from_sample_inputs_fn
227234
def rotate_segmentation_mask():
228235
for mask, angle, expand, center in itertools.product(
229-
make_segmentation_masks(),
236+
make_detection_and_segmentation_masks(),
230237
[-87, 15, 90], # angle
231238
[True, False], # expand
232239
[None, [12, 23]], # center
@@ -269,7 +276,7 @@ def crop_bounding_box():
269276
@register_kernel_info_from_sample_inputs_fn
270277
def crop_segmentation_mask():
271278
for mask, top, left, height, width in itertools.product(
272-
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
279+
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
273280
):
274281
yield ArgsKwargs(
275282
mask,
@@ -307,7 +314,7 @@ def resized_crop_bounding_box():
307314
@register_kernel_info_from_sample_inputs_fn
308315
def resized_crop_segmentation_mask():
309316
for mask, top, left, height, width, size in itertools.product(
310-
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
317+
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
311318
):
312319
yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)
313320

@@ -326,7 +333,7 @@ def pad_image_tensor():
326333
@register_kernel_info_from_sample_inputs_fn
327334
def pad_segmentation_mask():
328335
for mask, padding, padding_mode in itertools.product(
329-
make_segmentation_masks(),
336+
make_detection_and_segmentation_masks(),
330337
[[1], [1, 1], [1, 1, 2, 2]], # padding
331338
["constant", "symmetric", "edge", "reflect"], # padding mode,
332339
):
@@ -374,7 +381,7 @@ def perspective_bounding_box():
374381
@register_kernel_info_from_sample_inputs_fn
375382
def perspective_segmentation_mask():
376383
for mask, perspective_coeffs in itertools.product(
377-
make_segmentation_masks(extra_dims=((), (4,))),
384+
make_detection_and_segmentation_masks(extra_dims=((), (4,))),
378385
[
379386
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
380387
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
@@ -411,7 +418,7 @@ def elastic_bounding_box():
411418

412419
@register_kernel_info_from_sample_inputs_fn
413420
def elastic_segmentation_mask():
414-
for mask in make_segmentation_masks(extra_dims=((), (4,))):
421+
for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):
415422
h, w = mask.shape[-2:]
416423
displacement = torch.rand(1, h, w, 2)
417424
yield ArgsKwargs(
@@ -440,7 +447,7 @@ def center_crop_bounding_box():
440447
@register_kernel_info_from_sample_inputs_fn
441448
def center_crop_segmentation_mask():
442449
for mask, output_size in itertools.product(
443-
make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
450+
make_detection_and_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
444451
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
445452
):
446453
yield ArgsKwargs(mask, output_size)
@@ -771,7 +778,8 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
771778
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
772779
return expected_mask.to(mask.device)
773780

774-
for mask in make_segmentation_masks(extra_dims=((), (4,))):
781+
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
782+
for mask in make_detection_masks(extra_dims=((), (4,))):
775783
output_mask = F.affine_segmentation_mask(
776784
mask,
777785
angle=angle,
@@ -1011,7 +1019,8 @@ def _compute_expected_mask(mask, angle_, expand_, center_):
10111019
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
10121020
return expected_mask.to(mask.device)
10131021

1014-
for mask in make_segmentation_masks(extra_dims=((), (4,))):
1022+
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
1023+
for mask in make_detection_masks(extra_dims=((), (4,))):
10151024
output_mask = F.rotate_segmentation_mask(
10161025
mask,
10171026
angle=angle,
@@ -1138,7 +1147,7 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
11381147

11391148
return expected
11401149

1141-
for mask in make_segmentation_masks():
1150+
for mask in make_detection_and_segmentation_masks():
11421151
if mask.device != torch.device(device):
11431152
mask = mask.to(device)
11441153
output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
@@ -1358,7 +1367,7 @@ def _compute_expected_mask(mask, padding_, padding_mode_):
13581367

13591368
return output
13601369

1361-
for mask in make_segmentation_masks():
1370+
for mask in make_detection_and_segmentation_masks():
13621371
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)
13631372

13641373
expected_mask = _compute_expected_mask(mask, padding, padding_mode)
@@ -1487,7 +1496,8 @@ def _compute_expected_mask(mask, pcoeffs_):
14871496

14881497
pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
14891498

1490-
for mask in make_segmentation_masks(extra_dims=((), (4,))):
1499+
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
1500+
for mask in make_detection_masks(extra_dims=((), (4,))):
14911501
mask = mask.to(device)
14921502

14931503
output_mask = F.perspective_segmentation_mask(
@@ -1649,14 +1659,18 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
16491659

16501660
@pytest.mark.parametrize("device", cpu_and_gpu())
16511661
@pytest.mark.parametrize(
1652-
"fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
1662+
"fn, make_samples",
1663+
[
1664+
(F.elastic_image_tensor, make_images),
1665+
# FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks.
1666+
(F.elastic_segmentation_mask, make_detection_masks),
1667+
],
16531668
)
16541669
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
16551670
in_box = [10, 15, 25, 35]
16561671
for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
16571672
c, h, w = sample.shape[-3:]
16581673
# Setup a dummy image with 4 points
1659-
print(sample.shape)
16601674
sample[..., in_box[1], in_box[0]] = torch.arange(10, 10 + c)
16611675
sample[..., in_box[3] - 1, in_box[0]] = torch.arange(20, 20 + c)
16621676
sample[..., in_box[3] - 1, in_box[2] - 1] = torch.arange(30, 30 + c)

test/test_prototype_transforms_utils.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask
6+
from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
77

88
from torchvision.prototype import features
99
from torchvision.prototype.transforms._utils import has_all, has_any
@@ -12,7 +12,7 @@
1212

1313
IMAGE = make_image(color_space=features.ColorSpace.RGB)
1414
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
15-
SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
15+
SEGMENTATION_MASK = make_detection_mask(size=IMAGE.image_size)
1616

1717

1818
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)