
Fix prototype transforms for (*, H, W) segmentation masks #6574


Merged · 4 commits · Sep 13, 2022
44 changes: 40 additions & 4 deletions test/prototype_common_utils.py
@@ -178,22 +178,58 @@ def make_one_hot_labels(
        yield make_one_hot_label(extra_dims_)


-def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
+def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
Collaborator (author):
The naming is now quite weird: we call a make_detection_mask function but get a features.SegmentationMask back. Now that we have somewhat settled on the terms "segmentation" and "detection" masks, can we maybe rename the class to features.Mask, since it represents both types?
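
For illustration, a rough sketch of what such a unified class could look like (hypothetical; the name, signature, and base class are assumptions, not the actual torchvision implementation):

import torch


class Mask(torch.Tensor):
    # Hypothetical sketch only: one class covering both mask types.
    # - segmentation masks: shape (*, H, W), values are category indices
    # - detection masks: shape (*, N, H, W), values are 0/1 per object
    def __new__(cls, data, *, dtype=None, device=None):
        tensor = torch.as_tensor(data, dtype=dtype, device=device)
        return tensor.as_subclass(cls)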

Contributor:
Sounds good to me!

Contributor:
On second thought, I remember @YosuaMichael telling us that Depth Perception also has a sort of valid mask (see #6495).

@YosuaMichael could you clarify if they operate exactly like Segmentation/Detection masks and if we could just reuse this object?

Collaborator (author):
To give more context:

  • a segmentation mask is of shape (*, H, W), where each value corresponds to a class index
  • a detection mask is of shape (*, N, H, W), where N denotes the number of objects and the values are otherwise "boolean" (0/1)

In both cases, the * in the shape denotes arbitrary batch dimensions.
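
A quick sketch of the two layouts (arbitrary sizes, just to illustrate the shapes):

import torch

# segmentation mask: (*, H, W), each pixel holds a category index
seg_mask = torch.randint(0, 10, (4, 32, 32))    # batch of 4, 10 categories

# detection mask: (*, N, H, W), one "boolean" (0/1) channel per object
det_mask = torch.randint(0, 2, (4, 6, 32, 32))  # batch of 4, 6 objects each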

Contributor:
@datumbox I think the valid_mask in depth perception will behave like a detection_mask with 2 objects (it is boolean, but can be implemented with 0/1 uint8 as well). So overall I think we can reuse it.
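
To sketch that reuse (assuming the valid mask is a per-pixel boolean, as in #6495):

import torch

# Depth-perception "valid" mask: per-pixel validity for one sample.
valid = torch.rand(32, 32) > 0.2                   # (H, W), boolean
# It can be stored like a detection mask with a single object channel:
as_detection = valid.unsqueeze(0).to(torch.uint8)  # (1, H, W), values in {0, 1}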

    # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
    size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
    num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
    shape = (*extra_dims, num_objects, *size)
    data = make_tensor(shape, low=0, high=2, dtype=dtype)
    return features.SegmentationMask(data)


def make_detection_masks(
    *,
    sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.uint8,),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
    num_objects=(1, 0, None),
):
    for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
        yield make_detection_mask(size=size, dtype=dtype, extra_dims=extra_dims_)

    for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
        yield make_detection_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)


def make_segmentation_mask(size=None, *, num_categories=None, extra_dims=(), dtype=torch.uint8):
    # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
    size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
    num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
    shape = (*extra_dims, *size)
    data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
    return features.SegmentationMask(data)


def make_segmentation_masks(
    *,
    sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.uint8,),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
-   num_objects=(1, 0, 10),
+   num_categories=(1, 2, None),
):
    for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
        yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)

-   for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
-       yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
+   for dtype, extra_dims_, num_categories_ in itertools.product(dtypes, extra_dims, num_categories):
+       yield make_segmentation_mask(size=sizes[0], num_categories=num_categories_, dtype=dtype, extra_dims=extra_dims_)


def make_detection_and_segmentation_masks(
    sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.uint8,),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
    num_objects=(1, 0, None),
    num_categories=(1, 2, None),
):
    yield from make_detection_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_objects=num_objects)
    yield from make_segmentation_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_categories=num_categories)
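
For reference, a short usage sketch of the helpers added above (illustrative only; the shapes follow the comments in the helpers, and the names are assumed to be in scope):

det = make_detection_mask((16, 16), num_objects=3)        # shape (3, 16, 16)
seg = make_segmentation_mask((16, 16), num_categories=5)  # shape (16, 16)

# The combined generator yields both layouts, so a consumer is exercised
# with (*, N, H, W) detection masks and (*, H, W) segmentation masks.
for mask in make_detection_and_segmentation_masks():
    assert mask.ndim >= 2  # every sample ends in (H, W)
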
16 changes: 12 additions & 4 deletions test/test_prototype_transforms.py
@@ -10,6 +10,8 @@
from prototype_common_utils import (
    make_bounding_box,
    make_bounding_boxes,
    make_detection_and_segmentation_masks,
    make_detection_mask,
    make_image,
    make_images,
    make_label,
@@ -62,6 +64,7 @@ def parametrize_from_transforms(*transforms):
make_one_hot_labels,
make_vanilla_tensor_images,
make_pil_images,
make_detection_and_segmentation_masks,
]:
inputs = list(creation_fn())
try:
@@ -131,7 +134,12 @@ def test_mixup_cutmix(self, transform, input):
# Check if we raise an error if sample contains bbox or mask or label
err_msg = "does not support bounding boxes, segmentation masks and plain labels"
input_copy = dict(input)
-for unsup_data in [make_label(), make_bounding_box(format="XYXY"), make_segmentation_mask()]:
+for unsup_data in [
+    make_label(),
+    make_bounding_box(format="XYXY"),
+    make_detection_mask(),
+    make_segmentation_mask(),
+]:
input_copy["unsupported"] = unsup_data
with pytest.raises(TypeError, match=err_msg):
transform(input_copy)
@@ -233,7 +241,7 @@ def test_convert_color_space_unsupported_types(self):
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
)

-for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
+for inpt in [make_bounding_box(format="XYXY"), make_detection_and_segmentation_masks()]:
output = transform(inpt)
assert output is inpt

@@ -1206,7 +1214,7 @@ def test__transform(self, mocker):
bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
label = features.Label(torch.randint(0, 10, size=(6,)))
ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
-masks = make_segmentation_mask((32, 24), num_objects=6)
+masks = make_detection_mask((32, 24), num_objects=6)

sample = [image, bboxes, label, ohe_label, masks]

@@ -1578,7 +1586,7 @@ def test__transform_culling(self, mocker):
bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
-segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
+segmentation_masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
labels = make_label(size=(batch_size,))

transform = transforms.FixedSizeCrop((-1, -1))
52 changes: 33 additions & 19 deletions test/test_prototype_transforms_functional.py
@@ -8,7 +8,14 @@
import torch.testing
import torchvision.prototype.transforms.functional as F
from common_utils import cpu_and_gpu
-from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks
+from prototype_common_utils import (
+    ArgsKwargs,
+    make_bounding_boxes,
+    make_detection_and_segmentation_masks,
+    make_detection_masks,
+    make_image,
+    make_images,
+)
from torch import jit
from torchvision.prototype import features
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
@@ -55,7 +62,7 @@ def horizontal_flip_bounding_box():

@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_segmentation_mask():
-for mask in make_segmentation_masks():
+for mask in make_detection_and_segmentation_masks():
yield ArgsKwargs(mask)


@@ -73,7 +80,7 @@ def vertical_flip_bounding_box():

@register_kernel_info_from_sample_inputs_fn
def vertical_flip_segmentation_mask():
-for mask in make_segmentation_masks():
+for mask in make_detection_and_segmentation_masks():
yield ArgsKwargs(mask)


@@ -118,7 +125,7 @@ def resize_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def resize_segmentation_mask():
for mask, max_size in itertools.product(
-make_segmentation_masks(),
+make_detection_and_segmentation_masks(),
[None, 34], # max_size
):
height, width = mask.shape[-2:]
@@ -173,7 +180,7 @@ def affine_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
for mask, angle, translate, scale, shear in itertools.product(
-make_segmentation_masks(),
+make_detection_and_segmentation_masks(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
@@ -226,7 +233,7 @@ def rotate_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def rotate_segmentation_mask():
for mask, angle, expand, center in itertools.product(
-make_segmentation_masks(),
+make_detection_and_segmentation_masks(),
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
@@ -269,7 +276,7 @@ def crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def crop_segmentation_mask():
for mask, top, left, height, width in itertools.product(
-make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
+make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
):
yield ArgsKwargs(
mask,
@@ -307,7 +314,7 @@ def resized_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def resized_crop_segmentation_mask():
for mask, top, left, height, width, size in itertools.product(
-make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
+make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
):
yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)

@@ -326,7 +333,7 @@ def pad_image_tensor():
@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
for mask, padding, padding_mode in itertools.product(
-make_segmentation_masks(),
+make_detection_and_segmentation_masks(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
@@ -374,7 +381,7 @@ def perspective_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def perspective_segmentation_mask():
for mask, perspective_coeffs in itertools.product(
-make_segmentation_masks(extra_dims=((), (4,))),
+make_detection_and_segmentation_masks(extra_dims=((), (4,))),
[
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
@@ -411,7 +418,7 @@ def elastic_bounding_box():

@register_kernel_info_from_sample_inputs_fn
def elastic_segmentation_mask():
-for mask in make_segmentation_masks(extra_dims=((), (4,))):
+for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield ArgsKwargs(
@@ -440,7 +447,7 @@ def center_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def center_crop_segmentation_mask():
for mask, output_size in itertools.product(
-make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
+make_detection_and_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield ArgsKwargs(mask, output_size)
@@ -771,7 +778,8 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
return expected_mask.to(mask.device)

-for mask in make_segmentation_masks(extra_dims=((), (4,))):
+# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
+for mask in make_detection_masks(extra_dims=((), (4,))):
Comment on lines +781 to +782
Collaborator (author):
This is due to the `assert mask.ndim == 3`, which was introduced in #6546 to make sure only detection masks are computed here. @vfdev-5, could you fix the expected-mask computation so that both types of masks are supported? If that is done, we should use

Suggested change:
-# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
-for mask in make_detection_masks(extra_dims=((), (4,))):
+for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):

here.

This comment also applies to other occurrences of the same FIXME comment.
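
One possible direction for that fix (a sketch only; the helper name and approach are assumptions): normalize a (*, H, W) segmentation mask into the (N, H, W) layout the current computation expects, then squeeze the result back afterwards.

def _as_detection_layout(mask):
    # Hypothetical helper: view an (H, W) segmentation mask as a
    # single-object (1, H, W) detection mask for the expected-mask loop.
    if mask.ndim == 2:
        return mask.unsqueeze(0), True
    return mask, False

# ...after computing `expected` in the (N, H, W) layout:
# expected = expected.squeeze(0) if was_segmentation else expected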

output_mask = F.affine_segmentation_mask(
mask,
angle=angle,
@@ -1011,7 +1019,8 @@ def _compute_expected_mask(mask, angle_, expand_, center_):
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
return expected_mask.to(mask.device)

-for mask in make_segmentation_masks(extra_dims=((), (4,))):
+# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
+for mask in make_detection_masks(extra_dims=((), (4,))):
output_mask = F.rotate_segmentation_mask(
mask,
angle=angle,
@@ -1138,7 +1147,7 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):

return expected

-for mask in make_segmentation_masks():
+for mask in make_detection_and_segmentation_masks():
if mask.device != torch.device(device):
mask = mask.to(device)
output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
@@ -1358,7 +1367,7 @@ def _compute_expected_mask(mask, padding_, padding_mode_):

return output

-for mask in make_segmentation_masks():
+for mask in make_detection_and_segmentation_masks():
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)

expected_mask = _compute_expected_mask(mask, padding, padding_mode)
@@ -1487,7 +1496,8 @@ def _compute_expected_mask(mask, pcoeffs_):

pcoeffs = _get_perspective_coeffs(startpoints, endpoints)

-for mask in make_segmentation_masks(extra_dims=((), (4,))):
+# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
+for mask in make_detection_masks(extra_dims=((), (4,))):
mask = mask.to(device)

output_mask = F.perspective_segmentation_mask(
@@ -1649,14 +1659,18 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
"fn, make_samples",
[
(F.elastic_image_tensor, make_images),
# FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks.
(F.elastic_segmentation_mask, make_detection_masks),
],
)
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
in_box = [10, 15, 25, 35]
for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
c, h, w = sample.shape[-3:]
# Setup a dummy image with 4 points
-print(sample.shape)
sample[..., in_box[1], in_box[0]] = torch.arange(10, 10 + c)
sample[..., in_box[3] - 1, in_box[0]] = torch.arange(20, 20 + c)
sample[..., in_box[3] - 1, in_box[2] - 1] = torch.arange(30, 30 + c)
4 changes: 2 additions & 2 deletions test/test_prototype_transforms_utils.py
@@ -3,7 +3,7 @@

import torch

-from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask
+from prototype_common_utils import make_bounding_box, make_detection_mask, make_image

from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any
@@ -12,7 +12,7 @@

IMAGE = make_image(color_space=features.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
-SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
+SEGMENTATION_MASK = make_detection_mask(size=IMAGE.image_size)


@pytest.mark.parametrize(