Skip to content

Fix prototype transforms for (*, H, W) segmentation masks #6574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 13, 2022

Conversation

pmeier
Copy link
Collaborator

@pmeier pmeier commented Sep 13, 2022

Most of our *_segmentation_mask kernels simply dispatch to the *_image_tensor kernel. However a segmentation mask (from the segmentation task) might have no channel dimension, i.e. shape (*, H, W). Passing this directly into the image kernel will fail for some that require this dimension to be there. For example

num_channels, height, width = img.shape[-3:]

This was not detected before because the CI only tested against "detection" masks, i.e. (*; N, H, W) where N denotes the number of objects.

All kernels work the same, but in case the input has no channel dimension, we .unsqueeze(0) before calling the image kernel and .squeeze(0) afterwards.

@@ -178,22 +178,58 @@ def make_one_hot_labels(
yield make_one_hot_label(extra_dims_)


def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming is now quite weird. We are calling a make_detection_mask function and get a features.SegmentationMask back. Now that we have somewhat settled on the terms segmentation and detection masks, can we maybe rename the class to features.Mask since it represents both types?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a second thought, I remember @YosuaMichael telling us that on Depth Perception they also have a sort of Valid Masks (see #6495).

@YosuaMichael could you clarify if they operate exactly like Segmentation/Detection masks and if we could just reuse this object?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To give more context:

  • segmentation mask is of shape (*, H, W) where each value corresponds to a class index
  • detection mask is of shape (*, N, H, W) where N denotes the number of objects and the values are otherwise "boolean"

In both cases the * in the shape denotes arbitrary batch dimensions

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@datumbox I think the valid_mask on depth perception will behave like detection_mask with 2 object (it is boolean, but can be implemented with 0/1 uint as well). So overall I think we can reuse it.

Comment on lines +781 to +782
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is due to the assert mask.ndim == 3. This was introduced in #6546 to make sure only detection masks are computed here. @vfdev-5 Could you fix the expected mask computation so both types of masks are supported? If that is done, we should use

Suggested change
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):

here.

This comment also applies to other occurrences of the same FIXME comment.

angle: float,
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
) -> torch.Tensor:
return affine_image_tensor(
mask,
if segmentation_mask.ndim < 3:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love to do that in a decorator, but torch.jit.script doesn't support it 😢

def handle_segmentation_mask(mask_kernel):
    @functools.wraps(mask_kernel)
    def wrapper(mask, *other_args, **kwargs):
        if mask.ndim < 3:
            mask = mask.unsqueeze(0)
            needs_squeeze = True
        else:
            needs_squeeze = False

        output = mask_kernel(mask, *other_args, **kwargs)

        if needs_squeeze:
            output = output.squeeze(0)

        return output

    return wrapper


@handle_segmentation_mask
def affine_segmentation_mask(mask, ...):
    ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The story of our lives...

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes on the functional looks good to me.

angle: float,
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
) -> torch.Tensor:
return affine_image_tensor(
mask,
if segmentation_mask.ndim < 3:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The story of our lives...

@pmeier
Copy link
Collaborator Author

pmeier commented Sep 13, 2022

Test failure is unrelated:

____________ test_eager_vs_scripted[gaussian_blur_image_tensor-50] _____________
Traceback (most recent call last):
  File "/home/runner/work/vision/vision/test/test_prototype_transforms_functional.py", line 602, in test_eager_vs_scripted
    torch.testing.assert_close(eager, scripted)
  File "/opt/hostedtoolcache/Python/3.7.13/x64/lib/python3.7/site-packages/torch/testing/_comparison.py", line 1359, in assert_close
    msg=msg,
  File "/opt/hostedtoolcache/Python/3.7.13/x64/lib/python3.7/site-packages/torch/testing/_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 1 / 1024 (0.1%)
Greatest absolute difference: 1 at index (1, 0, 15, 4)
Greatest relative difference: 0.006172839552164078 at index (1, 0, 15, 4)

@pmeier pmeier merged commit 2c19af3 into pytorch:main Sep 13, 2022
@pmeier pmeier deleted the segmentation-masks branch September 13, 2022 13:47
facebook-github-bot pushed a commit that referenced this pull request Sep 15, 2022
…6574)

Summary:
* add generator functions for segmentation masks

* update functional tests and fix kernels

* fix transforms tests

Reviewed By: jdsgomes

Differential Revision: D39543270

fbshipit-source-id: 6f56ba288277ce0dd98115a794b2002d5341e83f
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants